This commit is contained in:
kyx236
2026-03-04 20:25:39 +08:00
738 changed files with 138970 additions and 39525 deletions

View File

@@ -15,7 +15,6 @@ import (
"database/sql"
"encoding/json"
"errors"
"log"
"strconv"
"time"
@@ -25,6 +24,7 @@ import (
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
@@ -50,11 +50,6 @@ type accountRepository struct {
schedulerCache service.SchedulerCache
}
type tempUnschedSnapshot struct {
until *time.Time
reason string
}
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
@@ -127,7 +122,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
account.CreatedAt = created.CreatedAt
account.UpdatedAt = created.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
}
return nil
}
@@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
accountIDs = append(accountIDs, acc.ID)
}
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
@@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
out.AccountGroups = ags
}
if snap, ok := tempUnschedMap[entAcc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outByID[entAcc.ID] = out
}
@@ -388,7 +374,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
}
account.UpdatedAt = updated.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
}
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
r.syncSchedulerAccountSnapshot(ctx, account.ID)
@@ -429,7 +415,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
}
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
}
return nil
}
@@ -541,7 +527,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
},
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
}
return nil
}
@@ -576,7 +562,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
}
payload := map[string]any{"last_used": lastUsedPayload}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
}
return nil
}
@@ -591,7 +577,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
@@ -611,11 +597,48 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
}
account, err := r.GetByID(ctx, accountID)
if err != nil {
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
return
}
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
}
}
func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) {
if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 {
return
}
uniqueIDs := make([]int64, 0, len(accountIDs))
seen := make(map[int64]struct{}, len(accountIDs))
for _, id := range accountIDs {
if id <= 0 {
continue
}
if _, exists := seen[id]; exists {
continue
}
seen[id] = struct{}{}
uniqueIDs = append(uniqueIDs, id)
}
if len(uniqueIDs) == 0 {
return
}
accounts, err := r.GetByIDs(ctx, uniqueIDs)
if err != nil {
logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot read failed: count=%d err=%v", len(uniqueIDs), err)
return
}
for _, account := range accounts {
if account == nil {
continue
}
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot write failed: id=%d err=%v", account.ID, err)
}
}
}
@@ -649,7 +672,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
}
payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
}
return nil
}
@@ -666,7 +689,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
}
payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
}
return nil
}
@@ -739,7 +762,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
}
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
}
return nil
}
@@ -824,6 +847,51 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
now := time.Now()
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformEQ(platform),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
dbaccount.Not(dbaccount.HasAccountGroups()),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
}
now := time.Now()
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformIn(platforms...),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
dbaccount.Not(dbaccount.HasAccountGroups()),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
@@ -847,7 +915,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
}
return nil
}
@@ -894,7 +962,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
}
return nil
}
@@ -908,7 +976,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
}
return nil
}
@@ -927,7 +995,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
}
r.syncSchedulerAccountSnapshot(ctx, id)
return nil
@@ -946,7 +1014,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
}
return nil
}
@@ -962,7 +1030,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
}
return nil
}
@@ -986,7 +1054,7 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
}
return nil
}
@@ -1010,7 +1078,7 @@ func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64)
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
}
return nil
}
@@ -1032,7 +1100,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
// 触发调度器缓存更新(仅当窗口时间有变化时)
if start != nil || end != nil {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
}
}
return nil
@@ -1047,7 +1115,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
}
if !schedulable {
r.syncSchedulerAccountSnapshot(ctx, id)
@@ -1075,7 +1143,7 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
}
if rows > 0 {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
}
}
return rows, nil
@@ -1111,7 +1179,7 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
}
return nil
}
@@ -1205,7 +1273,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if rows > 0 {
payload := map[string]any{"account_ids": ids}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
}
shouldSync := false
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
@@ -1215,9 +1283,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
shouldSync = true
}
if shouldSync {
for _, id := range ids {
r.syncSchedulerAccountSnapshot(ctx, id)
}
r.syncSchedulerAccountSnapshots(ctx, ids)
}
}
return rows, nil
@@ -1309,10 +1375,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if err != nil {
return nil, err
}
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
@@ -1338,10 +1400,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
out.AccountGroups = ags
}
if snap, ok := tempUnschedMap[acc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outAccounts = append(outAccounts, *out)
}
@@ -1366,48 +1424,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account {
)
}
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
out := make(map[int64]tempUnschedSnapshot)
if len(accountIDs) == 0 {
return out, nil
}
rows, err := r.sql.QueryContext(ctx, `
SELECT id, temp_unschedulable_until, temp_unschedulable_reason
FROM accounts
WHERE id = ANY($1)
`, pq.Array(accountIDs))
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var id int64
var until sql.NullTime
var reason sql.NullString
if err := rows.Scan(&id, &until, &reason); err != nil {
return nil, err
}
var untilPtr *time.Time
if until.Valid {
tmp := until.Time
untilPtr = &tmp
}
if reason.Valid {
out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String}
} else {
out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""}
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
proxyMap := make(map[int64]*service.Proxy)
if len(proxyIDs) == 0 {
@@ -1518,31 +1534,33 @@ func accountEntityToService(m *dbent.Account) *service.Account {
rateMultiplier := m.RateMultiplier
return &service.Account{
ID: m.ID,
Name: m.Name,
Notes: m.Notes,
Platform: m.Platform,
Type: m.Type,
Credentials: copyJSONMap(m.Credentials),
Extra: copyJSONMap(m.Extra),
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
RateMultiplier: &rateMultiplier,
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,
ExpiresAt: m.ExpiresAt,
AutoPauseOnExpired: m.AutoPauseOnExpired,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,
RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil,
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: derefString(m.SessionWindowStatus),
ID: m.ID,
Name: m.Name,
Notes: m.Notes,
Platform: m.Platform,
Type: m.Type,
Credentials: copyJSONMap(m.Credentials),
Extra: copyJSONMap(m.Extra),
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
RateMultiplier: &rateMultiplier,
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,
ExpiresAt: m.ExpiresAt,
AutoPauseOnExpired: m.AutoPauseOnExpired,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,
RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil,
TempUnschedulableUntil: m.TempUnschedulableUntil,
TempUnschedulableReason: derefString(m.TempUnschedulableReason),
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: derefString(m.SessionWindowStatus),
}
}
@@ -1578,3 +1596,64 @@ func joinClauses(clauses []string, sep string) string {
func itoa(v int) string {
return strconv.Itoa(v)
}
// FindByExtraField 根据 extra 字段中的键值对查找账号。
// 该方法限定 platform='sora',避免误查询其他平台的账号。
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
//
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
//
// FindByExtraField finds accounts by key-value pairs in the extra field.
// Limited to platform='sora' to avoid querying accounts from other platforms.
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
//
// Use case: Finding Sora accounts linked via linked_openai_account_id.
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
dbaccount.DeletedAtIsNil(),
func(s *entsql.Selector) {
path := sqljson.Path(key)
switch v := value.(type) {
case string:
preds := []*entsql.Predicate{sqljson.ValueEQ(dbaccount.FieldExtra, v, path)}
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
preds = append(preds, sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path))
}
if len(preds) == 1 {
s.Where(preds[0])
} else {
s.Where(entsql.Or(preds...))
}
case int:
s.Where(entsql.Or(
sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
sqljson.ValueEQ(dbaccount.FieldExtra, strconv.Itoa(v), path),
))
case int64:
s.Where(entsql.Or(
sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
sqljson.ValueEQ(dbaccount.FieldExtra, strconv.FormatInt(v, 10), path),
))
case json.Number:
if parsed, err := v.Int64(); err == nil {
s.Where(entsql.Or(
sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path),
sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path),
))
} else {
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path))
}
default:
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, value, path))
}
},
).
All(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
return r.accountsToService(ctx, accounts)
}

View File

@@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
s.Require().Nil(got.OverloadUntil)
}
func (s *AccountRepoSuite) TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs() {
acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-1"})
acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-2"})
until := time.Now().Add(15 * time.Minute).UTC().Truncate(time.Second)
reason := `{"rule":"429","matched_keyword":"too many requests"}`
s.Require().NoError(s.repo.SetTempUnschedulable(s.ctx, acc1.ID, until, reason))
gotByID, err := s.repo.GetByID(s.ctx, acc1.ID)
s.Require().NoError(err)
s.Require().NotNil(gotByID.TempUnschedulableUntil)
s.Require().WithinDuration(until, *gotByID.TempUnschedulableUntil, time.Second)
s.Require().Equal(reason, gotByID.TempUnschedulableReason)
gotByIDs, err := s.repo.GetByIDs(s.ctx, []int64{acc2.ID, acc1.ID})
s.Require().NoError(err)
s.Require().Len(gotByIDs, 2)
s.Require().Equal(acc2.ID, gotByIDs[0].ID)
s.Require().Nil(gotByIDs[0].TempUnschedulableUntil)
s.Require().Equal("", gotByIDs[0].TempUnschedulableReason)
s.Require().Equal(acc1.ID, gotByIDs[1].ID)
s.Require().NotNil(gotByIDs[1].TempUnschedulableUntil)
s.Require().WithinDuration(until, *gotByIDs[1].TempUnschedulableUntil, time.Second)
s.Require().Equal(reason, gotByIDs[1].TempUnschedulableReason)
s.Require().NoError(s.repo.ClearTempUnschedulable(s.ctx, acc1.ID))
cleared, err := s.repo.GetByID(s.ctx, acc1.ID)
s.Require().NoError(err)
s.Require().Nil(cleared.TempUnschedulableUntil)
s.Require().Equal("", cleared.TempUnschedulableReason)
}
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {

View File

@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewAPIKeyRepository(entClient)
apiKeyRepo := newAPIKeyRepositoryWithSQL(entClient, tx)
u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com",

View File

@@ -2,6 +2,7 @@ package repository
import (
"context"
"database/sql"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -16,10 +17,15 @@ import (
type apiKeyRepository struct {
client *dbent.Client
sql sqlExecutor
}
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
return &apiKeyRepository{client: client}
func NewAPIKeyRepository(client *dbent.Client, sqlDB *sql.DB) service.APIKeyRepository {
return newAPIKeyRepositoryWithSQL(client, sqlDB)
}
func newAPIKeyRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *apiKeyRepository {
return &apiKeyRepository{client: client, sql: sqlq}
}
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
@@ -34,9 +40,13 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
SetName(key.Name).
SetStatus(key.Status).
SetNillableGroupID(key.GroupID).
SetNillableLastUsedAt(key.LastUsedAt).
SetQuota(key.Quota).
SetQuotaUsed(key.QuotaUsed).
SetNillableExpiresAt(key.ExpiresAt)
SetNillableExpiresAt(key.ExpiresAt).
SetRateLimit5h(key.RateLimit5h).
SetRateLimit1d(key.RateLimit1d).
SetRateLimit7d(key.RateLimit7d)
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
@@ -48,6 +58,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
created, err := builder.Save(ctx)
if err == nil {
key.ID = created.ID
key.LastUsedAt = created.LastUsedAt
key.CreatedAt = created.CreatedAt
key.UpdatedAt = created.UpdatedAt
}
@@ -116,6 +127,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
apikey.FieldQuota,
apikey.FieldQuotaUsed,
apikey.FieldExpiresAt,
apikey.FieldRateLimit5h,
apikey.FieldRateLimit1d,
apikey.FieldRateLimit7d,
).
WithUser(func(q *dbent.UserQuery) {
q.Select(
@@ -140,6 +154,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
group.FieldSoraImagePrice360,
group.FieldSoraImagePrice540,
group.FieldSoraVideoPricePerRequest,
group.FieldSoraVideoPricePerRequestHd,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
group.FieldFallbackGroupIDOnInvalidRequest,
@@ -165,13 +183,20 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at避免二次查询带来的并发可见性问题。
client := clientFromContext(ctx, r.client)
now := time.Now()
builder := r.client.APIKey.Update().
builder := client.APIKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).
SetQuota(key.Quota).
SetQuotaUsed(key.QuotaUsed).
SetRateLimit5h(key.RateLimit5h).
SetRateLimit1d(key.RateLimit1d).
SetRateLimit7d(key.RateLimit7d).
SetUsage5h(key.Usage5h).
SetUsage1d(key.Usage1d).
SetUsage7d(key.Usage7d).
SetUpdatedAt(now)
if key.GroupID != nil {
builder.SetGroupID(*key.GroupID)
@@ -186,6 +211,23 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder.ClearExpiresAt()
}
// Rate limit window start times
if key.Window5hStart != nil {
builder.SetWindow5hStart(*key.Window5hStart)
} else {
builder.ClearWindow5hStart()
}
if key.Window1dStart != nil {
builder.SetWindow1dStart(*key.Window1dStart)
} else {
builder.ClearWindow1dStart()
}
if key.Window7dStart != nil {
builder.SetWindow7dStart(*key.Window7dStart)
} else {
builder.ClearWindow7dStart()
}
// IP 限制字段
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
@@ -239,9 +281,27 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return nil
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
// Apply filters
if filters.Search != "" {
q = q.Where(apikey.Or(
apikey.NameContainsFold(filters.Search),
apikey.KeyContainsFold(filters.Search),
))
}
if filters.Status != "" {
q = q.Where(apikey.StatusEQ(filters.Status))
}
if filters.GroupID != nil {
if *filters.GroupID == 0 {
q = q.Where(apikey.GroupIDIsNil())
} else {
q = q.Where(apikey.GroupIDEQ(*filters.GroupID))
}
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
@@ -375,36 +435,92 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
return keys, nil
}
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
// Use raw SQL for atomic increment to avoid race conditions
// First get current value
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
Select(apikey.FieldQuotaUsed).
Only(ctx)
updated, err := r.client.APIKey.UpdateOneID(id).
Where(apikey.DeletedAtIsNil()).
AddQuotaUsed(amount).
Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrAPIKeyNotFound
}
return 0, err
}
return updated.QuotaUsed, nil
}
newValue := m.QuotaUsed + amount
// Update with new value
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetQuotaUsed(newValue).
SetLastUsedAt(usedAt).
SetUpdatedAt(usedAt).
Save(ctx)
if err != nil {
return 0, err
return err
}
if affected == 0 {
return 0, service.ErrAPIKeyNotFound
return service.ErrAPIKeyNotFound
}
return nil
}
return newValue, nil
// IncrementRateLimitUsage atomically increments all rate limit usage counters and initializes
// window start times via COALESCE if not already set.
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = usage_5h + $1,
usage_1d = usage_1d + $1,
usage_7d = usage_7d + $1,
window_5h_start = COALESCE(window_5h_start, NOW()),
window_1d_start = COALESCE(window_1d_start, NOW()),
window_7d_start = COALESCE(window_7d_start, NOW()),
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL`,
cost, id)
return err
}
// ResetRateLimitWindows resets expired rate limit windows atomically.
func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL`,
id)
return err
}
// GetRateLimitData returns the current rate limit usage and window start times for an API key.
func (r *apiKeyRepository) GetRateLimitData(ctx context.Context, id int64) (result *service.APIKeyRateLimitData, err error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT usage_5h, usage_1d, usage_7d, window_5h_start, window_1d_start, window_7d_start
FROM api_keys
WHERE id = $1 AND deleted_at IS NULL`,
id)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
}
}()
if !rows.Next() {
return nil, service.ErrAPIKeyNotFound
}
data := &service.APIKeyRateLimitData{}
if err := rows.Scan(&data.Usage5h, &data.Usage1d, &data.Usage7d, &data.Window5hStart, &data.Window1dStart, &data.Window7dStart); err != nil {
return nil, err
}
return data, rows.Err()
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
@@ -412,19 +528,29 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
return nil
}
out := &service.APIKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
Status: m.Status,
IPWhitelist: m.IPWhitelist,
IPBlacklist: m.IPBlacklist,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
Quota: m.Quota,
QuotaUsed: m.QuotaUsed,
ExpiresAt: m.ExpiresAt,
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
Status: m.Status,
IPWhitelist: m.IPWhitelist,
IPBlacklist: m.IPBlacklist,
LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
Quota: m.Quota,
QuotaUsed: m.QuotaUsed,
ExpiresAt: m.ExpiresAt,
RateLimit5h: m.RateLimit5h,
RateLimit1d: m.RateLimit1d,
RateLimit7d: m.RateLimit7d,
Usage5h: m.Usage5h,
Usage1d: m.Usage1d,
Usage7d: m.Usage7d,
Window5hStart: m.Window5hStart,
Window1dStart: m.Window1dStart,
Window7dStart: m.Window7dStart,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
@@ -440,20 +566,22 @@ func userEntityToService(u *dbent.User) *service.User {
return nil
}
return &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
@@ -477,6 +605,11 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
SoraImagePrice360: g.SoraImagePrice360,
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,

View File

@@ -4,11 +4,14 @@ package repository
import (
"context"
"sync"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -23,7 +26,7 @@ func (s *APIKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
s.repo = newAPIKeyRepositoryWithSQL(s.client, tx)
}
func TestAPIKeyRepoSuite(t *testing.T) {
@@ -155,7 +158,7 @@ func (s *APIKeyRepoSuite) TestListByUserID() {
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
s.Require().NoError(err, "ListByUserID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
@@ -167,7 +170,7 @@ func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}, service.APIKeyListFilters{})
s.Require().NoError(err)
s.Require().Len(keys, 2)
s.Require().Equal(int64(5), page.Total)
@@ -311,7 +314,7 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(service.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
s.Require().NoError(err, "ListByUserID")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(keys, 1)
@@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
return k
}
// --- IncrementQuotaUsed ---
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() {
user := s.mustCreateUser("incr-basic@test.com")
key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil)
newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5)
s.Require().NoError(err, "IncrementQuotaUsed")
s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5")
newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5)
s.Require().NoError(err, "IncrementQuotaUsed second")
s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0")
}
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() {
_, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0)
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound")
}
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
user := s.mustCreateUser("incr-deleted@test.com")
key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil)
s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete")
_, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0)
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
}
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
// 注意:此测试使用 testEntClient非事务隔离数据会真正写入数据库。
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
client := testEntClient(t)
repo := NewAPIKeyRepository(client, integrationDB).(*apiKeyRepository)
ctx := context.Background()
// 创建测试用户和 API Key
u, err := client.User.Create().
SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com").
SetPasswordHash("hash").
SetStatus(service.StatusActive).
SetRole(service.RoleUser).
Save(ctx)
require.NoError(t, err, "create user")
k := &service.APIKey{
UserID: u.ID,
Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano),
Name: "Concurrent",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, k), "create api key")
t.Cleanup(func() {
_ = client.APIKey.DeleteOneID(k.ID).Exec(ctx)
_ = client.User.DeleteOneID(u.ID).Exec(ctx)
})
// 10 个 goroutine 各递增 1.0,总计应为 10.0
const goroutines = 10
const increment = 1.0
var wg sync.WaitGroup
errs := make([]error, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
_, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment)
}(i)
}
wg.Wait()
for i, e := range errs {
require.NoError(t, e, "goroutine %d failed", i)
}
// 验证最终结果
got, err := repo.GetByID(ctx, k.ID)
require.NoError(t, err, "GetByID")
require.Equal(t, float64(goroutines)*increment, got.QuotaUsed,
"并发递增后总和应为 %v实际为 %v", float64(goroutines)*increment, got.QuotaUsed)
}

View File

@@ -0,0 +1,156 @@
package repository
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func newAPIKeyRepoSQLite(t *testing.T) (*apiKeyRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:api_key_repo_last_used?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return &apiKeyRepository{client: client}, client
}
func mustCreateAPIKeyRepoUser(t *testing.T, ctx context.Context, client *dbent.Client, email string) *service.User {
t.Helper()
u, err := client.User.Create().
SetEmail(email).
SetPasswordHash("test-password-hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
return userEntityToService(u)
}
func TestAPIKeyRepository_CreateWithLastUsedAt(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "create-last-used@test.com")
lastUsed := time.Now().UTC().Add(-time.Hour).Truncate(time.Second)
key := &service.APIKey{
UserID: user.ID,
Key: "sk-create-last-used",
Name: "CreateWithLastUsed",
Status: service.StatusActive,
LastUsedAt: &lastUsed,
}
require.NoError(t, repo.Create(ctx, key))
require.NotNil(t, key.LastUsedAt)
require.WithinDuration(t, lastUsed, *key.LastUsedAt, time.Second)
got, err := repo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.NotNil(t, got.LastUsedAt)
require.WithinDuration(t, lastUsed, *got.LastUsedAt, time.Second)
}
func TestAPIKeyRepository_UpdateLastUsed(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "update-last-used@test.com")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-update-last-used",
Name: "UpdateLastUsed",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key))
before, err := repo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.Nil(t, before.LastUsedAt)
target := time.Now().UTC().Add(2 * time.Minute).Truncate(time.Second)
require.NoError(t, repo.UpdateLastUsed(ctx, key.ID, target))
after, err := repo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.NotNil(t, after.LastUsedAt)
require.WithinDuration(t, target, *after.LastUsedAt, time.Second)
require.WithinDuration(t, target, after.UpdatedAt, time.Second)
}
func TestAPIKeyRepository_UpdateLastUsedDeletedKey(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "deleted-last-used@test.com")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-update-last-used-deleted",
Name: "UpdateLastUsedDeleted",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key))
require.NoError(t, repo.Delete(ctx, key.ID))
err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC())
require.ErrorIs(t, err, service.ErrAPIKeyNotFound)
}
func TestAPIKeyRepository_UpdateLastUsedDBError(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "db-error-last-used@test.com")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-update-last-used-db-error",
Name: "UpdateLastUsedDBError",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key))
require.NoError(t, client.Close())
err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC())
require.Error(t, err)
}
func TestAPIKeyRepository_CreateDuplicateKey(t *testing.T) {
repo, client := newAPIKeyRepoSQLite(t)
ctx := context.Background()
user := mustCreateAPIKeyRepoUser(t, ctx, client, "duplicate-key@test.com")
first := &service.APIKey{
UserID: user.ID,
Key: "sk-duplicate",
Name: "first",
Status: service.StatusActive,
}
second := &service.APIKey{
UserID: user.ID,
Key: "sk-duplicate",
Name: "second",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, first))
err := repo.Create(ctx, second)
require.ErrorIs(t, err, service.ErrAPIKeyExists)
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"math/rand/v2"
"strconv"
"time"
@@ -13,11 +14,24 @@ import (
)
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingRateLimitKeyPrefix = "apikey:rate:"
billingCacheTTL = 5 * time.Minute
billingCacheJitter = 30 * time.Second
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
)
// jitteredTTL 返回带随机抖动的 TTL防止缓存雪崩
func jitteredTTL() time.Duration {
// 只做“减法抖动”,确保实际 TTL 不会超过 billingCacheTTL避免上界预期被打破
if billingCacheJitter <= 0 {
return billingCacheTTL
}
jitter := time.Duration(rand.IntN(int(billingCacheJitter)))
return billingCacheTTL - jitter
}
// billingBalanceKey generates the Redis key for user balance cache.
func billingBalanceKey(userID int64) string {
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
@@ -37,6 +51,20 @@ const (
subFieldVersion = "version"
)
// billingRateLimitKey generates the Redis key for API key rate limit cache.
func billingRateLimitKey(keyID int64) string {
return fmt.Sprintf("%s%d", billingRateLimitKeyPrefix, keyID)
}
const (
rateLimitFieldUsage5h = "usage_5h"
rateLimitFieldUsage1d = "usage_1d"
rateLimitFieldUsage7d = "usage_7d"
rateLimitFieldWindow5h = "window_5h"
rateLimitFieldWindow1d = "window_1d"
rateLimitFieldWindow7d = "window_7d"
)
var (
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
@@ -61,6 +89,21 @@ var (
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
// updateRateLimitUsageScript atomically increments all three rate limit usage counters.
// Returns 0 if the key doesn't exist (cache miss), 1 on success.
updateRateLimitUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
type billingCache struct {
@@ -82,14 +125,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := billingBalanceKey(userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err()
}
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := billingBalanceKey(userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
return err
}
return nil
}
@@ -163,16 +207,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
pipe.Expire(ctx, key, jitteredTTL())
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := billingSubKey(userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
return err
}
return nil
}
@@ -181,3 +226,69 @@ func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID,
key := billingSubKey(userID, groupID)
return c.rdb.Del(ctx, key).Err()
}
func (c *billingCache) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*service.APIKeyRateLimitCacheData, error) {
key := billingRateLimitKey(keyID)
result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, redis.Nil
}
data := &service.APIKeyRateLimitCacheData{}
if v, ok := result[rateLimitFieldUsage5h]; ok {
data.Usage5h, _ = strconv.ParseFloat(v, 64)
}
if v, ok := result[rateLimitFieldUsage1d]; ok {
data.Usage1d, _ = strconv.ParseFloat(v, 64)
}
if v, ok := result[rateLimitFieldUsage7d]; ok {
data.Usage7d, _ = strconv.ParseFloat(v, 64)
}
if v, ok := result[rateLimitFieldWindow5h]; ok {
data.Window5h, _ = strconv.ParseInt(v, 10, 64)
}
if v, ok := result[rateLimitFieldWindow1d]; ok {
data.Window1d, _ = strconv.ParseInt(v, 10, 64)
}
if v, ok := result[rateLimitFieldWindow7d]; ok {
data.Window7d, _ = strconv.ParseInt(v, 10, 64)
}
return data, nil
}
func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *service.APIKeyRateLimitCacheData) error {
if data == nil {
return nil
}
key := billingRateLimitKey(keyID)
fields := map[string]any{
rateLimitFieldUsage5h: data.Usage5h,
rateLimitFieldUsage1d: data.Usage1d,
rateLimitFieldUsage7d: data.Usage7d,
rateLimitFieldWindow5h: data.Window5h,
rateLimitFieldWindow1d: data.Window1d,
rateLimitFieldWindow7d: data.Window7d,
}
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, rateLimitCacheTTL)
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
key := billingRateLimitKey(keyID)
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
return err
}
return nil
}
func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
key := billingRateLimitKey(keyID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
}
}
// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
// Redis 真实错误应传播key 不存在redis.Nil应返回 nil。
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
tests := []struct {
name string
fn func(ctx context.Context, cache service.BillingCache)
expectErr bool
}{
{
name: "key_not_exists_returns_nil",
fn: func(ctx context.Context, cache service.BillingCache) {
// key 不存在时Lua 脚本返回 0redis.Nil应返回 nil 而非错误
err := cache.DeductUserBalance(ctx, 99999, 1.0)
require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
},
},
{
name: "existing_key_deducts_successfully",
fn: func(ctx context.Context, cache service.BillingCache) {
require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0))
err := cache.DeductUserBalance(ctx, 200, 10.0)
require.NoError(s.T(), err, "DeductUserBalance should succeed")
bal, err := cache.GetUserBalance(ctx, 200)
require.NoError(s.T(), err)
require.Equal(s.T(), 40.0, bal, "余额应为 40.0")
},
},
{
name: "cancelled_context_propagates_error",
fn: func(ctx context.Context, cache service.BillingCache) {
require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0))
cancelCtx, cancel := context.WithCancel(ctx)
cancel() // 立即取消
err := cache.DeductUserBalance(cancelCtx, 201, 10.0)
require.Error(s.T(), err, "cancelled context should propagate error")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, cache)
})
}
}
// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复:
// Redis 真实错误应传播key 不存在redis.Nil应返回 nil。
func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() {
s.Run("key_not_exists_returns_nil", func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0)
require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil")
})
s.Run("cancelled_context_propagates_error", func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data))
cancelCtx, cancel := context.WithCancel(ctx)
cancel()
err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0)
require.Error(s.T(), err, "cancelled context should propagate error")
})
}
func TestBillingCacheSuite(t *testing.T) {
suite.Run(t, new(BillingCacheSuite))
}

View File

@@ -0,0 +1,82 @@
package repository
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Task 6.1 验证: math/rand/v2 迁移后 jitteredTTL 行为正确 ---
func TestJitteredTTL_WithinExpectedRange(t *testing.T) {
// jitteredTTL 使用减法抖动: billingCacheTTL - [0, billingCacheJitter)
// 所以结果应在 [billingCacheTTL - billingCacheJitter, billingCacheTTL] 范围内
lowerBound := billingCacheTTL - billingCacheJitter // 5min - 30s = 4min30s
upperBound := billingCacheTTL // 5min
for i := 0; i < 200; i++ {
ttl := jitteredTTL()
assert.GreaterOrEqual(t, int64(ttl), int64(lowerBound),
"TTL 不应低于 %v实际得到 %v", lowerBound, ttl)
assert.LessOrEqual(t, int64(ttl), int64(upperBound),
"TTL 不应超过 %v上界不变保证实际得到 %v", upperBound, ttl)
}
}
func TestJitteredTTL_NeverExceedsBase(t *testing.T) {
// 关键安全性测试jitteredTTL 使用减法抖动,确保永远不超过 billingCacheTTL
for i := 0; i < 500; i++ {
ttl := jitteredTTL()
assert.LessOrEqual(t, int64(ttl), int64(billingCacheTTL),
"jitteredTTL 不应超过基础 TTL上界预期不被打破")
}
}
func TestJitteredTTL_HasVariance(t *testing.T) {
// 验证抖动确实产生了不同的值
results := make(map[time.Duration]bool)
for i := 0; i < 100; i++ {
ttl := jitteredTTL()
results[ttl] = true
}
require.Greater(t, len(results), 1,
"jitteredTTL 应产生不同的值(抖动生效),但 100 次调用结果全部相同")
}
func TestJitteredTTL_AverageNearCenter(t *testing.T) {
// 验证平均值大约在抖动范围中间
var sum time.Duration
runs := 1000
for i := 0; i < runs; i++ {
sum += jitteredTTL()
}
avg := sum / time.Duration(runs)
expectedCenter := billingCacheTTL - billingCacheJitter/2 // 4min45s
// 允许 ±5s 的误差
tolerance := 5 * time.Second
assert.InDelta(t, float64(expectedCenter), float64(avg), float64(tolerance),
"平均 TTL 应接近抖动范围中心 %v", expectedCenter)
}
func TestBillingKeyGeneration(t *testing.T) {
t.Run("balance_key", func(t *testing.T) {
key := billingBalanceKey(12345)
assert.Equal(t, "billing:balance:12345", key)
})
t.Run("sub_key", func(t *testing.T) {
key := billingSubKey(100, 200)
assert.Equal(t, "billing:sub:100:200", key)
})
}
func BenchmarkJitteredTTL(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = jitteredTTL()
}
}

View File

@@ -5,6 +5,7 @@ package repository
import (
"math"
"testing"
"time"
"github.com/stretchr/testify/require"
)
@@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) {
})
}
}
func TestJitteredTTL(t *testing.T) {
const (
minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s
maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s
)
for i := 0; i < 200; i++ {
ttl := jitteredTTL()
require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl)
require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl)
}
}
func TestJitteredTTL_HasVariation(t *testing.T) {
// 多次调用应该产生不同的值(验证抖动存在)
seen := make(map[time.Duration]struct{}, 50)
for i := 0; i < 50; i++ {
seen[jitteredTTL()] = struct{}{}
}
// 50 次调用中应该至少有 2 个不同的值
require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值")
}

View File

@@ -4,13 +4,14 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
@@ -28,11 +29,14 @@ func NewClaudeOAuthClient() service.ClaudeOAuthClient {
type claudeOAuthService struct {
baseURL string
tokenURL string
clientFactory func(proxyURL string) *req.Client
clientFactory func(proxyURL string) (*req.Client, error)
}
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := s.clientFactory(proxyURL)
client, err := s.clientFactory(proxyURL)
if err != nil {
return "", fmt.Errorf("create HTTP client: %w", err)
}
var orgs []struct {
UUID string `json:"uuid"`
@@ -41,7 +45,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
}
targetURL := s.baseURL + "/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
SetContext(ctx).
@@ -53,11 +57,11 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
Get(targetURL)
if err != nil {
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
@@ -69,26 +73,29 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
// 如果只有一个组织,直接使用
if len(orgs) == 1 {
log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil
}
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
for _, org := range orgs {
if org.RavenType != nil && *org.RavenType == "team" {
log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
org.UUID, org.Name, *org.RavenType)
return org.UUID, nil
}
}
// 如果没有 team 类型的组织,使用第一个
log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil
}
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := s.clientFactory(proxyURL)
client, err := s.clientFactory(proxyURL)
if err != nil {
return "", fmt.Errorf("create HTTP client: %w", err)
}
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
@@ -103,9 +110,9 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
"code_challenge_method": "S256",
}
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2: Getting authorization code from %s", authURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct {
RedirectURI string `json:"redirect_uri"`
@@ -128,11 +135,11 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
Post(authURL)
if err != nil {
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
@@ -160,12 +167,15 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 SUCCESS - Got authorization code")
return fullCode, nil
}
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL)
client, err := s.clientFactory(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
// Parse code which may contain state in format "authCode#state"
authCode := code
@@ -192,9 +202,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
}
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
@@ -208,22 +218,25 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
Post(s.tokenURL)
if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 FAILED - Request error: %v", err)
return nil, fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 SUCCESS - Got access token")
return &tokenResp, nil
}
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL)
client, err := s.clientFactory(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
reqBody := map[string]any{
"grant_type": "refresh_token",
@@ -253,16 +266,20 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
return &tokenResp, nil
}
func createReqClient(proxyURL string) *req.Client {
func createReqClient(proxyURL string) (*req.Client, error) {
// 禁用 CookieJar确保每次授权都是干净的会话
client := req.C().
SetTimeout(60 * time.Second).
ImpersonateChrome().
SetCookieJar(nil) // 禁用 CookieJar
if strings.TrimSpace(proxyURL) != "" {
client.SetProxyURL(strings.TrimSpace(proxyURL))
trimmed, _, err := proxyurl.Parse(proxyURL)
if err != nil {
return nil, err
}
if trimmed != "" {
client.SetProxyURL(trimmed)
}
return client
return client, nil
}

View File

@@ -91,7 +91,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
@@ -169,7 +169,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "")
@@ -276,7 +276,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
@@ -372,7 +372,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
resp, err := s.client.RefreshToken(context.Background(), "rt", "")

View File

@@ -83,7 +83,7 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
return nil, fmt.Errorf("create http client failed: %w", err)
}
resp, err = client.Do(req)

View File

@@ -50,7 +50,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
allowPrivateHosts: true,
}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.NoError(s.T(), err, "FetchUsage")
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
@@ -112,6 +112,17 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
require.Error(s.T(), err, "expected error for cancelled context")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_InvalidProxyReturnsError() {
s.fetcher = &claudeUsageService{
usageURL: "http://example.com",
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "create http client failed")
}
func TestClaudeUsageServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeUsageServiceSuite))
}

View File

@@ -147,100 +147,6 @@ var (
return 1
`)
// getAccountsLoadBatchScript - batch load query with expired slot cleanup
// ARGV[1] = slot TTL (seconds)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
getAccountsLoadBatchScript = redis.NewScript(`
local result = {}
local slotTTL = tonumber(ARGV[1])
-- Get current server time
local timeResult = redis.call('TIME')
local nowSeconds = tonumber(timeResult[1])
local cutoffTime = nowSeconds - slotTTL
local i = 2
while i <= #ARGV do
local accountID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:account:' .. accountID
-- Clean up expired slots before counting
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'wait:account:' .. accountID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, accountID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`)
// getUsersLoadBatchScript - batch load query for users with expired slot cleanup
// ARGV[1] = slot TTL (seconds)
// ARGV[2..n] = userID1, maxConcurrency1, userID2, maxConcurrency2, ...
getUsersLoadBatchScript = redis.NewScript(`
local result = {}
local slotTTL = tonumber(ARGV[1])
-- Get current server time
local timeResult = redis.call('TIME')
local nowSeconds = tonumber(timeResult[1])
local cutoffTime = nowSeconds - slotTTL
local i = 2
while i <= #ARGV do
local userID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:user:' .. userID
-- Clean up expired slots before counting
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'concurrency:wait:' .. userID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, userID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
@@ -321,6 +227,43 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
return result, nil
}
func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return map[int64]int{}, nil
}
now, err := c.rdb.Time(ctx).Result()
if err != nil {
return nil, fmt.Errorf("redis TIME: %w", err)
}
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
pipe := c.rdb.Pipeline()
type accountCmd struct {
accountID int64
zcardCmd *redis.IntCmd
}
cmds := make([]accountCmd, 0, len(accountIDs))
for _, accountID := range accountIDs {
slotKey := accountSlotKeyPrefix + strconv.FormatInt(accountID, 10)
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
cmds = append(cmds, accountCmd{
accountID: accountID,
zcardCmd: pipe.ZCard(ctx, slotKey),
})
}
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
return nil, fmt.Errorf("pipeline exec: %w", err)
}
result := make(map[int64]int, len(accountIDs))
for _, cmd := range cmds {
result[cmd.accountID] = int(cmd.zcardCmd.Val())
}
return result, nil
}
// User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
@@ -399,29 +342,53 @@ func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []
return map[int64]*service.AccountLoadInfo{}, nil
}
args := []any{c.slotTTLSeconds}
for _, acc := range accounts {
args = append(args, acc.ID, acc.MaxConcurrency)
}
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
// 使用 Pipeline 替代 Lua 脚本,兼容 Redis ClusterLua 内动态拼 key 会 CROSSSLOT
// 每个账号执行 3 个命令ZREMRANGEBYSCORE清理过期、ZCARD并发数、GET等待数
now, err := c.rdb.Time(ctx).Result()
if err != nil {
return nil, err
return nil, fmt.Errorf("redis TIME: %w", err)
}
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
pipe := c.rdb.Pipeline()
type accountCmds struct {
id int64
maxConcurrency int
zcardCmd *redis.IntCmd
getCmd *redis.StringCmd
}
cmds := make([]accountCmds, 0, len(accounts))
for _, acc := range accounts {
slotKey := accountSlotKeyPrefix + strconv.FormatInt(acc.ID, 10)
waitKey := accountWaitKeyPrefix + strconv.FormatInt(acc.ID, 10)
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
ac := accountCmds{
id: acc.ID,
maxConcurrency: acc.MaxConcurrency,
zcardCmd: pipe.ZCard(ctx, slotKey),
getCmd: pipe.Get(ctx, waitKey),
}
cmds = append(cmds, ac)
}
loadMap := make(map[int64]*service.AccountLoadInfo)
for i := 0; i < len(result); i += 4 {
if i+3 >= len(result) {
break
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
return nil, fmt.Errorf("pipeline exec: %w", err)
}
loadMap := make(map[int64]*service.AccountLoadInfo, len(accounts))
for _, ac := range cmds {
currentConcurrency := int(ac.zcardCmd.Val())
waitingCount := 0
if v, err := ac.getCmd.Int(); err == nil {
waitingCount = v
}
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
loadMap[accountID] = &service.AccountLoadInfo{
AccountID: accountID,
loadRate := 0
if ac.maxConcurrency > 0 {
loadRate = (currentConcurrency + waitingCount) * 100 / ac.maxConcurrency
}
loadMap[ac.id] = &service.AccountLoadInfo{
AccountID: ac.id,
CurrentConcurrency: currentConcurrency,
WaitingCount: waitingCount,
LoadRate: loadRate,
@@ -436,29 +403,52 @@ func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []servic
return map[int64]*service.UserLoadInfo{}, nil
}
args := []any{c.slotTTLSeconds}
for _, u := range users {
args = append(args, u.ID, u.MaxConcurrency)
}
result, err := getUsersLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
// 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster。
now, err := c.rdb.Time(ctx).Result()
if err != nil {
return nil, err
return nil, fmt.Errorf("redis TIME: %w", err)
}
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
pipe := c.rdb.Pipeline()
type userCmds struct {
id int64
maxConcurrency int
zcardCmd *redis.IntCmd
getCmd *redis.StringCmd
}
cmds := make([]userCmds, 0, len(users))
for _, u := range users {
slotKey := userSlotKeyPrefix + strconv.FormatInt(u.ID, 10)
waitKey := waitQueueKeyPrefix + strconv.FormatInt(u.ID, 10)
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
uc := userCmds{
id: u.ID,
maxConcurrency: u.MaxConcurrency,
zcardCmd: pipe.ZCard(ctx, slotKey),
getCmd: pipe.Get(ctx, waitKey),
}
cmds = append(cmds, uc)
}
loadMap := make(map[int64]*service.UserLoadInfo)
for i := 0; i < len(result); i += 4 {
if i+3 >= len(result) {
break
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
return nil, fmt.Errorf("pipeline exec: %w", err)
}
loadMap := make(map[int64]*service.UserLoadInfo, len(users))
for _, uc := range cmds {
currentConcurrency := int(uc.zcardCmd.Val())
waitingCount := 0
if v, err := uc.getCmd.Int(); err == nil {
waitingCount = v
}
userID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
loadMap[userID] = &service.UserLoadInfo{
UserID: userID,
loadRate := 0
if uc.maxConcurrency > 0 {
loadRate = (currentConcurrency + waitingCount) * 100 / uc.maxConcurrency
}
loadMap[uc.id] = &service.UserLoadInfo{
UserID: uc.id,
CurrentConcurrency: currentConcurrency,
WaitingCount: waitingCount,
LoadRate: loadRate,

View File

@@ -5,6 +5,7 @@ package repository
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/ent"
@@ -66,6 +67,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
client := ent.NewClient(ent.Driver(drv))
// 启动阶段:从配置或数据库中确保系统密钥可用。
if err := ensureBootstrapSecrets(migrationCtx, client, cfg); err != nil {
_ = client.Close()
return nil, nil, err
}
// 在密钥补齐后执行完整配置校验,避免空 jwt.secret 导致服务运行时失败。
if err := cfg.Validate(); err != nil {
_ = client.Close()
return nil, nil, fmt.Errorf("validate config after secret bootstrap: %w", err)
}
// SIMPLE 模式:启动时补齐各平台默认分组。
// - anthropic/openai/gemini: 确保存在 <platform>-default
// - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景)

View File

@@ -104,7 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite))
}

View File

@@ -0,0 +1,9 @@
package repository
import "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
// NewGeminiDriveClient creates a concrete DriveClient for Google Drive API operations.
// Returned as geminicli.DriveClient interface for DI (Strategy A).
func NewGeminiDriveClient() geminicli.DriveClient {
return geminicli.NewDriveClient()
}

View File

@@ -26,7 +26,10 @@ func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient {
}
func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) {
client := createGeminiReqClient(proxyURL)
client, err := createGeminiReqClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
// Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public)
@@ -72,7 +75,10 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
}
func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) {
client := createGeminiReqClient(proxyURL)
client, err := createGeminiReqClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,
@@ -111,7 +117,7 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
return &tokenResp, nil
}
func createGeminiReqClient(proxyURL string) *req.Client {
func createGeminiReqClient(proxyURL string) (*req.Client, error) {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,

View File

@@ -26,7 +26,11 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo
}
var out geminicli.LoadCodeAssistResponse
resp, err := createGeminiCliReqClient(proxyURL).R().
client, err := createGeminiCliReqClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
resp, err := client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Content-Type", "application/json").
@@ -66,7 +70,11 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody)
var out geminicli.OnboardUserResponse
resp, err := createGeminiCliReqClient(proxyURL).R().
client, err := createGeminiCliReqClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
resp, err := client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Content-Type", "application/json").
@@ -98,7 +106,7 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
return &out, nil
}
func createGeminiCliReqClient(proxyURL string) *req.Client {
func createGeminiCliReqClient(proxyURL string) (*req.Client, error) {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,

View File

@@ -5,8 +5,10 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
@@ -18,14 +20,27 @@ type githubReleaseClient struct {
downloadHTTPClient *http.Client
}
type githubReleaseClientError struct {
err error
}
// NewGitHubReleaseClient 创建 GitHub Release 客户端
// proxyURL 为空时直连 GitHub支持 http/https/socks5/socks5h 协议
func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
// 代理配置失败时行为由 allowDirectOnProxyError 控制:
// - false默认返回错误占位客户端禁止回退到直连
// - true回退到直连仅限管理员显式开启
func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient {
// 安全说明httpclient.GetClient 的错误链url.Parse / proxyutil不含明文代理凭据
// 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
slog.Warn("proxy client init failed, all requests will fail", "service", "github_release", "error", err)
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
}
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
@@ -35,6 +50,10 @@ func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
ProxyURL: proxyURL,
})
if err != nil {
if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
slog.Warn("proxy download client init failed, all requests will fail", "service", "github_release", "error", err)
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
}
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
@@ -44,6 +63,18 @@ func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
}
}
func (c *githubReleaseClientError) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
return nil, c.err
}
func (c *githubReleaseClientError) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
return c.err
}
func (c *githubReleaseClientError) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
return nil, c.err
}
func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)

View File

@@ -4,11 +4,13 @@ import (
"context"
"database/sql"
"errors"
"log"
"fmt"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
@@ -47,12 +49,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject)
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
@@ -68,7 +75,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
groupIn.CreatedAt = created.CreatedAt
groupIn.UpdatedAt = created.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
}
}
return translatePersistenceError(err, nil, service.ErrGroupExists)
@@ -110,10 +117,47 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject)
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
// 显式处理可空字段nil 需要 clear非 nil 需要 set。
if groupIn.DailyLimitUSD != nil {
builder = builder.SetDailyLimitUsd(*groupIn.DailyLimitUSD)
} else {
builder = builder.ClearDailyLimitUsd()
}
if groupIn.WeeklyLimitUSD != nil {
builder = builder.SetWeeklyLimitUsd(*groupIn.WeeklyLimitUSD)
} else {
builder = builder.ClearWeeklyLimitUsd()
}
if groupIn.MonthlyLimitUSD != nil {
builder = builder.SetMonthlyLimitUsd(*groupIn.MonthlyLimitUSD)
} else {
builder = builder.ClearMonthlyLimitUsd()
}
if groupIn.ImagePrice1K != nil {
builder = builder.SetImagePrice1k(*groupIn.ImagePrice1K)
} else {
builder = builder.ClearImagePrice1k()
}
if groupIn.ImagePrice2K != nil {
builder = builder.SetImagePrice2k(*groupIn.ImagePrice2K)
} else {
builder = builder.ClearImagePrice2k()
}
if groupIn.ImagePrice4K != nil {
builder = builder.SetImagePrice4k(*groupIn.ImagePrice4K)
} else {
builder = builder.ClearImagePrice4k()
}
// 处理 FallbackGroupIDnil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
@@ -144,7 +188,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
}
groupIn.UpdatedAt = updated.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
}
return nil
}
@@ -155,7 +199,7 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
}
return nil
}
@@ -183,7 +227,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
q = q.Where(group.IsExclusiveEQ(*isExclusive))
}
total, err := q.Count(ctx)
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
@@ -273,6 +317,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx)
}
// ExistsByIDs 批量检查分组是否存在(仅检查未软删除记录)。
// 返回结构map[groupID]exists。
func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) {
result := make(map[int64]bool, len(ids))
if len(ids) == 0 {
return result, nil
}
uniqueIDs := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
uniqueIDs = append(uniqueIDs, id)
result[id] = false
}
if len(uniqueIDs) == 0 {
return result, nil
}
rows, err := r.sql.QueryContext(ctx, `
SELECT id
FROM groups
WHERE id = ANY($1) AND deleted_at IS NULL
`, pq.Array(uniqueIDs))
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, err
}
result[id] = true
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
@@ -288,7 +380,7 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou
}
affected, _ := res.RowsAffected()
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
}
return affected, nil
}
@@ -398,7 +490,7 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
}
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
}
return affectedUserIDs, nil
@@ -492,7 +584,7 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64
// 发送调度器事件
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
}
return nil
@@ -504,22 +596,72 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic
return nil
}
// 使用事务批量更新
tx, err := r.client.Tx(ctx)
// 去重后保留最后一次排序值,避免重复 ID 造成 CASE 分支冲突。
sortOrderByID := make(map[int64]int, len(updates))
groupIDs := make([]int64, 0, len(updates))
for _, u := range updates {
if u.ID <= 0 {
continue
}
if _, exists := sortOrderByID[u.ID]; !exists {
groupIDs = append(groupIDs, u.ID)
}
sortOrderByID[u.ID] = u.SortOrder
}
if len(groupIDs) == 0 {
return nil
}
// 与旧实现保持一致:任何不存在/已删除的分组都返回 not found且不执行更新。
var existingCount int
if err := scanSingleRow(
ctx,
r.sql,
`SELECT COUNT(*) FROM groups WHERE deleted_at IS NULL AND id = ANY($1)`,
[]any{pq.Array(groupIDs)},
&existingCount,
); err != nil {
return err
}
if existingCount != len(groupIDs) {
return service.ErrGroupNotFound
}
args := make([]any, 0, len(groupIDs)*2+1)
caseClauses := make([]string, 0, len(groupIDs))
placeholder := 1
for _, id := range groupIDs {
caseClauses = append(caseClauses, fmt.Sprintf("WHEN $%d THEN $%d", placeholder, placeholder+1))
args = append(args, id, sortOrderByID[id])
placeholder += 2
}
args = append(args, pq.Array(groupIDs))
query := fmt.Sprintf(`
UPDATE groups
SET sort_order = CASE id
%s
ELSE sort_order
END
WHERE deleted_at IS NULL AND id = ANY($%d)
`, strings.Join(caseClauses, "\n\t\t\t"), placeholder)
result, err := r.sql.ExecContext(ctx, query, args...)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for _, u := range updates {
if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
}
if err := tx.Commit(); err != nil {
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected != int64(len(groupIDs)) {
return service.ErrGroupNotFound
}
for _, id := range groupIDs {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group sort update failed: group=%d err=%v", id, err)
}
}
return nil
}

View File

@@ -352,6 +352,81 @@ func (s *GroupRepoSuite) TestListWithFilters_Search() {
})
}
func (s *GroupRepoSuite) TestUpdateSortOrders_BatchCaseWhen() {
g1 := &service.Group{
Name: "sort-g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
g2 := &service.Group{
Name: "sort-g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
g3 := &service.Group{
Name: "sort-g3",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
s.Require().NoError(s.repo.Create(s.ctx, g3))
err := s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
{ID: g1.ID, SortOrder: 30},
{ID: g2.ID, SortOrder: 10},
{ID: g3.ID, SortOrder: 20},
{ID: g2.ID, SortOrder: 15}, // 重复 ID 应以最后一次为准
})
s.Require().NoError(err)
got1, err := s.repo.GetByID(s.ctx, g1.ID)
s.Require().NoError(err)
got2, err := s.repo.GetByID(s.ctx, g2.ID)
s.Require().NoError(err)
got3, err := s.repo.GetByID(s.ctx, g3.ID)
s.Require().NoError(err)
s.Require().Equal(30, got1.SortOrder)
s.Require().Equal(15, got2.SortOrder)
s.Require().Equal(20, got3.SortOrder)
}
func (s *GroupRepoSuite) TestUpdateSortOrders_MissingGroupNoPartialUpdate() {
g1 := &service.Group{
Name: "sort-no-partial",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g1))
before, err := s.repo.GetByID(s.ctx, g1.ID)
s.Require().NoError(err)
beforeSort := before.SortOrder
err = s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
{ID: g1.ID, SortOrder: 99},
{ID: 99999999, SortOrder: 1},
})
s.Require().Error(err)
s.Require().ErrorIs(err, service.ErrGroupNotFound)
after, err := s.repo.GetByID(s.ctx, g1.ID)
s.Require().NoError(err)
s.Require().Equal(beforeSort, after.SortOrder)
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := &service.Group{
Name: "g1",

View File

@@ -14,6 +14,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -235,7 +236,10 @@ func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID in
// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
isolation := s.getIsolationMode()
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL)
if err != nil {
return nil, err
}
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
@@ -373,9 +377,8 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac
// - proxy: 按代理地址隔离,同一代理共享客户端
// - account: 按账户隔离,同一账户共享客户端(代理变更时重建)
// - account_proxy: 按账户+代理组合隔离,最细粒度
func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry {
entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
return entry
func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
return s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
}
// getClientEntry 获取或创建客户端条目
@@ -385,7 +388,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
// 获取隔离模式
isolation := s.getIsolationMode()
// 标准化代理 URL 并解析
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL)
if err != nil {
return nil, err
}
// 构建缓存键(根据隔离策略不同)
cacheKey := buildCacheKey(isolation, proxyKey, accountID)
// 构建连接池配置键(用于检测配置变更)
@@ -680,17 +686,18 @@ func buildCacheKey(isolation, proxyKey string, accountID int64) string {
// - raw: 原始代理 URL 字符串
//
// 返回:
// - string: 标准化的代理键(空或解析失败返回 "direct"
// - *url.URL: 解析后的 URL或解析失败返回 nil
func normalizeProxyURL(raw string) (string, *url.URL) {
proxyURL := strings.TrimSpace(raw)
if proxyURL == "" {
return directProxyKey, nil
}
parsed, err := url.Parse(proxyURL)
// - string: 标准化的代理键(空返回 "direct"
// - *url.URL: 解析后的 URL空返回 nil
// - error: 非空代理 URL 解析失败时返回错误(禁止回退到直连)
func normalizeProxyURL(raw string) (string, *url.URL, error) {
_, parsed, err := proxyurl.Parse(raw)
if err != nil {
return directProxyKey, nil
return "", nil, err
}
if parsed == nil {
return directProxyKey, nil, nil
}
// 规范化:小写 scheme/host去除路径和查询参数
parsed.Scheme = strings.ToLower(parsed.Scheme)
parsed.Host = strings.ToLower(parsed.Host)
parsed.Path = ""
@@ -710,7 +717,7 @@ func normalizeProxyURL(raw string) (string, *url.URL) {
parsed.Host = hostname
}
}
return parsed.String(), parsed
return parsed.String(), parsed, nil
}
// defaultPoolSettings 获取默认连接池配置

View File

@@ -59,7 +59,10 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
// 模拟优化后的行为,从缓存获取客户端
b.Run("复用", func(b *testing.B) {
// 预热:确保客户端已缓存
entry := svc.getOrCreateClient(proxyURL, 1, 1)
entry, err := svc.getOrCreateClient(proxyURL, 1, 1)
if err != nil {
b.Fatalf("getOrCreateClient: %v", err)
}
client := entry.client
b.ResetTimer() // 重置计时器,排除预热时间
for i := 0; i < b.N; i++ {

View File

@@ -44,7 +44,7 @@ func (s *HTTPUpstreamSuite) newService() *httpUpstreamService {
// 验证未配置时使用 300 秒默认值
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
svc := s.newService()
entry := svc.getOrCreateClient("", 0, 0)
entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
@@ -55,25 +55,27 @@ func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
svc := s.newService()
entry := svc.getOrCreateClient("", 0, 0)
entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}
// TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退
// 验证解析失败时回退到直连模式
func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() {
// TestGetOrCreateClient_InvalidURLReturnsError 测试无效代理 URL 返回错误
// 验证解析失败时拒绝回退到直连模式
func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLReturnsError() {
svc := s.newService()
entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1)
require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback")
_, err := svc.getClientEntry("://bad-proxy-url", 1, 1, false, false)
require.Error(s.T(), err, "expected error for invalid proxy URL")
}
// TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化
// 验证等价地址能够映射到同一缓存键
func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() {
key1, _ := normalizeProxyURL("http://proxy.local:8080")
key2, _ := normalizeProxyURL("http://proxy.local:8080/")
key1, _, err1 := normalizeProxyURL("http://proxy.local:8080")
require.NoError(s.T(), err1)
key2, _, err2 := normalizeProxyURL("http://proxy.local:8080/")
require.NoError(s.T(), err2)
require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match")
}
@@ -171,8 +173,8 @@ func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一代理,不同账户
entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3)
entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 1, 3)
entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 2, 3)
require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池")
require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端")
}
@@ -183,8 +185,8 @@ func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy}
svc := s.newService()
// 同一账户,不同代理
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3)
entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理")
require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端")
}
@@ -195,8 +197,8 @@ func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一账户,先后使用不同代理
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3)
entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池")
require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池")
require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理")
@@ -208,7 +210,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 账户并发数为 12
entry := svc.getOrCreateClient("", 1, 12)
entry := mustGetOrCreateClient(s.T(), svc, "", 1, 12)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
// 连接池参数应与并发数一致
@@ -228,7 +230,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() {
}
svc := s.newService()
// 账户并发数为 0应使用全局配置
entry := svc.getOrCreateClient("", 1, 0)
entry := mustGetOrCreateClient(s.T(), svc, "", 1, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch")
@@ -245,12 +247,12 @@ func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() {
}
svc := s.newService()
// 创建两个客户端,设置不同的最后使用时间
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1)
entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 1)
entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 2, 1)
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久
atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano())
// 创建第三个客户端,触发淘汰
_ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1)
_ = mustGetOrCreateClient(s.T(), svc, "http://proxy-c:8080", 3, 1)
require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内")
require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理")
@@ -264,12 +266,12 @@ func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() {
ClientIdleTTLSeconds: 1, // 1 秒空闲超时
}
svc := s.newService()
entry1 := svc.getOrCreateClient("", 1, 1)
entry1 := mustGetOrCreateClient(s.T(), svc, "", 1, 1)
// 设置为很久之前使用,但有活跃请求
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano())
atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求
// 创建新客户端,触发淘汰检查
_ = svc.getOrCreateClient("", 2, 1)
_, _ = svc.getOrCreateClient("", 2, 1)
require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收")
}
@@ -279,6 +281,14 @@ func TestHTTPUpstreamSuite(t *testing.T) {
suite.Run(t, new(HTTPUpstreamSuite))
}
// mustGetOrCreateClient 测试辅助函数,调用 getOrCreateClient 并断言无错误
func mustGetOrCreateClient(t *testing.T, svc *httpUpstreamService, proxyURL string, accountID int64, concurrency int) *upstreamClientEntry {
t.Helper()
entry, err := svc.getOrCreateClient(proxyURL, accountID, concurrency)
require.NoError(t, err, "getOrCreateClient(%q, %d, %d)", proxyURL, accountID, concurrency)
return entry
}
// hasEntry 检查客户端是否存在于缓存中
// 辅助函数,用于验证淘汰逻辑
func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool {

View File

@@ -0,0 +1,237 @@
package repository
import (
"context"
"database/sql"
"errors"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type idempotencyRepository struct {
sql sqlExecutor
}
func NewIdempotencyRepository(_ *dbent.Client, sqlDB *sql.DB) service.IdempotencyRepository {
return &idempotencyRepository{sql: sqlDB}
}
func (r *idempotencyRepository) CreateProcessing(ctx context.Context, record *service.IdempotencyRecord) (bool, error) {
if record == nil {
return false, nil
}
query := `
INSERT INTO idempotency_records (
scope, idempotency_key_hash, request_fingerprint, status, locked_until, expires_at
) VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (scope, idempotency_key_hash) DO NOTHING
RETURNING id, created_at, updated_at
`
var createdAt time.Time
var updatedAt time.Time
err := scanSingleRow(ctx, r.sql, query, []any{
record.Scope,
record.IdempotencyKeyHash,
record.RequestFingerprint,
record.Status,
record.LockedUntil,
record.ExpiresAt,
}, &record.ID, &createdAt, &updatedAt)
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
if err != nil {
return false, err
}
record.CreatedAt = createdAt
record.UpdatedAt = updatedAt
return true, nil
}
func (r *idempotencyRepository) GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
query := `
SELECT
id, scope, idempotency_key_hash, request_fingerprint, status, response_status,
response_body, error_reason, locked_until, expires_at, created_at, updated_at
FROM idempotency_records
WHERE scope = $1 AND idempotency_key_hash = $2
`
record := &service.IdempotencyRecord{}
var responseStatus sql.NullInt64
var responseBody sql.NullString
var errorReason sql.NullString
var lockedUntil sql.NullTime
err := scanSingleRow(ctx, r.sql, query, []any{scope, keyHash},
&record.ID,
&record.Scope,
&record.IdempotencyKeyHash,
&record.RequestFingerprint,
&record.Status,
&responseStatus,
&responseBody,
&errorReason,
&lockedUntil,
&record.ExpiresAt,
&record.CreatedAt,
&record.UpdatedAt,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
if err != nil {
return nil, err
}
if responseStatus.Valid {
v := int(responseStatus.Int64)
record.ResponseStatus = &v
}
if responseBody.Valid {
v := responseBody.String
record.ResponseBody = &v
}
if errorReason.Valid {
v := errorReason.String
record.ErrorReason = &v
}
if lockedUntil.Valid {
v := lockedUntil.Time
record.LockedUntil = &v
}
return record, nil
}
func (r *idempotencyRepository) TryReclaim(
ctx context.Context,
id int64,
fromStatus string,
now, newLockedUntil, newExpiresAt time.Time,
) (bool, error) {
query := `
UPDATE idempotency_records
SET status = $2,
locked_until = $3,
error_reason = NULL,
updated_at = NOW(),
expires_at = $4
WHERE id = $1
AND status = $5
AND (locked_until IS NULL OR locked_until <= $6)
`
res, err := r.sql.ExecContext(ctx, query,
id,
service.IdempotencyStatusProcessing,
newLockedUntil,
newExpiresAt,
fromStatus,
now,
)
if err != nil {
return false, err
}
affected, err := res.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *idempotencyRepository) ExtendProcessingLock(
ctx context.Context,
id int64,
requestFingerprint string,
newLockedUntil,
newExpiresAt time.Time,
) (bool, error) {
query := `
UPDATE idempotency_records
SET locked_until = $2,
expires_at = $3,
updated_at = NOW()
WHERE id = $1
AND status = $4
AND request_fingerprint = $5
`
res, err := r.sql.ExecContext(
ctx,
query,
id,
newLockedUntil,
newExpiresAt,
service.IdempotencyStatusProcessing,
requestFingerprint,
)
if err != nil {
return false, err
}
affected, err := res.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *idempotencyRepository) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
query := `
UPDATE idempotency_records
SET status = $2,
response_status = $3,
response_body = $4,
error_reason = NULL,
locked_until = NULL,
expires_at = $5,
updated_at = NOW()
WHERE id = $1
`
_, err := r.sql.ExecContext(ctx, query,
id,
service.IdempotencyStatusSucceeded,
responseStatus,
responseBody,
expiresAt,
)
return err
}
func (r *idempotencyRepository) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
query := `
UPDATE idempotency_records
SET status = $2,
error_reason = $3,
locked_until = $4,
expires_at = $5,
updated_at = NOW()
WHERE id = $1
`
_, err := r.sql.ExecContext(ctx, query,
id,
service.IdempotencyStatusFailedRetryable,
errorReason,
lockedUntil,
expiresAt,
)
return err
}
func (r *idempotencyRepository) DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) {
if limit <= 0 {
limit = 500
}
query := `
WITH victims AS (
SELECT id
FROM idempotency_records
WHERE expires_at <= $1
ORDER BY expires_at ASC
LIMIT $2
)
DELETE FROM idempotency_records
WHERE id IN (SELECT id FROM victims)
`
res, err := r.sql.ExecContext(ctx, query, now, limit)
if err != nil {
return 0, err
}
return res.RowsAffected()
}

View File

@@ -0,0 +1,149 @@
//go:build integration
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// hashedTestValue returns a unique SHA-256 hex string (64 chars) that fits VARCHAR(64) columns.
func hashedTestValue(t *testing.T, prefix string) string {
t.Helper()
sum := sha256.Sum256([]byte(uniqueTestValue(t, prefix)))
return hex.EncodeToString(sum[:])
}
func TestIdempotencyRepo_CreateProcessing_CompeteSameKey(t *testing.T) {
tx := testTx(t)
repo := &idempotencyRepository{sql: tx}
ctx := context.Background()
now := time.Now().UTC()
record := &service.IdempotencyRecord{
Scope: uniqueTestValue(t, "idem-scope-create"),
IdempotencyKeyHash: hashedTestValue(t, "idem-hash"),
RequestFingerprint: hashedTestValue(t, "idem-fp"),
Status: service.IdempotencyStatusProcessing,
LockedUntil: ptrTime(now.Add(30 * time.Second)),
ExpiresAt: now.Add(24 * time.Hour),
}
owner, err := repo.CreateProcessing(ctx, record)
require.NoError(t, err)
require.True(t, owner)
require.NotZero(t, record.ID)
duplicate := &service.IdempotencyRecord{
Scope: record.Scope,
IdempotencyKeyHash: record.IdempotencyKeyHash,
RequestFingerprint: hashedTestValue(t, "idem-fp-other"),
Status: service.IdempotencyStatusProcessing,
LockedUntil: ptrTime(now.Add(30 * time.Second)),
ExpiresAt: now.Add(24 * time.Hour),
}
owner, err = repo.CreateProcessing(ctx, duplicate)
require.NoError(t, err)
require.False(t, owner, "same scope+key hash should be de-duplicated")
}
func TestIdempotencyRepo_TryReclaim_StatusAndLockWindow(t *testing.T) {
tx := testTx(t)
repo := &idempotencyRepository{sql: tx}
ctx := context.Background()
now := time.Now().UTC()
record := &service.IdempotencyRecord{
Scope: uniqueTestValue(t, "idem-scope-reclaim"),
IdempotencyKeyHash: hashedTestValue(t, "idem-hash-reclaim"),
RequestFingerprint: hashedTestValue(t, "idem-fp-reclaim"),
Status: service.IdempotencyStatusProcessing,
LockedUntil: ptrTime(now.Add(10 * time.Second)),
ExpiresAt: now.Add(24 * time.Hour),
}
owner, err := repo.CreateProcessing(ctx, record)
require.NoError(t, err)
require.True(t, owner)
require.NoError(t, repo.MarkFailedRetryable(
ctx,
record.ID,
"RETRYABLE_FAILURE",
now.Add(-2*time.Second),
now.Add(24*time.Hour),
))
newLockedUntil := now.Add(20 * time.Second)
reclaimed, err := repo.TryReclaim(
ctx,
record.ID,
service.IdempotencyStatusFailedRetryable,
now,
newLockedUntil,
now.Add(24*time.Hour),
)
require.NoError(t, err)
require.True(t, reclaimed, "failed_retryable + expired lock should allow reclaim")
got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, service.IdempotencyStatusProcessing, got.Status)
require.NotNil(t, got.LockedUntil)
require.True(t, got.LockedUntil.After(now))
require.NoError(t, repo.MarkFailedRetryable(
ctx,
record.ID,
"RETRYABLE_FAILURE",
now.Add(20*time.Second),
now.Add(24*time.Hour),
))
reclaimed, err = repo.TryReclaim(
ctx,
record.ID,
service.IdempotencyStatusFailedRetryable,
now,
now.Add(40*time.Second),
now.Add(24*time.Hour),
)
require.NoError(t, err)
require.False(t, reclaimed, "within lock window should not reclaim")
}
func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) {
tx := testTx(t)
repo := &idempotencyRepository{sql: tx}
ctx := context.Background()
now := time.Now().UTC()
record := &service.IdempotencyRecord{
Scope: uniqueTestValue(t, "idem-scope-success"),
IdempotencyKeyHash: hashedTestValue(t, "idem-hash-success"),
RequestFingerprint: hashedTestValue(t, "idem-fp-success"),
Status: service.IdempotencyStatusProcessing,
LockedUntil: ptrTime(now.Add(10 * time.Second)),
ExpiresAt: now.Add(24 * time.Hour),
}
owner, err := repo.CreateProcessing(ctx, record)
require.NoError(t, err)
require.True(t, owner)
require.NoError(t, repo.MarkSucceeded(ctx, record.ID, 200, `{"ok":true}`, now.Add(24*time.Hour)))
got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, service.IdempotencyStatusSucceeded, got.Status)
require.NotNil(t, got.ResponseStatus)
require.Equal(t, 200, *got.ResponseStatus)
require.NotNil(t, got.ResponseBody)
require.Equal(t, `{"ok":true}`, *got.ResponseBody)
require.Nil(t, got.LockedUntil)
}

View File

@@ -12,7 +12,7 @@ import (
const (
fingerprintKeyPrefix = "fingerprint:"
fingerprintTTL = 24 * time.Hour
fingerprintTTL = 7 * 24 * time.Hour // 7天配合每24小时懒续期可保持活跃账号永不过期
maskedSessionKeyPrefix = "masked_session:"
maskedSessionTTL = 15 * time.Minute
)

View File

@@ -50,6 +50,30 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
const migrationsAdvisoryLockID int64 = 694208311321144027
const migrationsLockRetryInterval = 500 * time.Millisecond
const nonTransactionalMigrationSuffix = "_notx.sql"
type migrationChecksumCompatibilityRule struct {
fileChecksum string
acceptedDBChecksum map[string]struct{}
}
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行避免放宽全局校验。
var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{
"054_drop_legacy_cache_columns.sql": {
fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
acceptedDBChecksum: map[string]struct{}{
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
},
},
"061_add_usage_log_request_type.sql": {
fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
acceptedDBChecksum: map[string]struct{}{
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {},
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
},
},
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
//
@@ -147,6 +171,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if rowErr == nil {
// 迁移已应用,验证校验和是否匹配
if existing != checksum {
// 兼容特定历史误改场景(仅白名单规则),其余仍保持严格不可变约束。
if isMigrationChecksumCompatible(name, existing, checksum) {
continue
}
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf(
@@ -165,8 +193,34 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return fmt.Errorf("check migration %s: %w", name, rowErr)
}
// 迁移未应用,在事务中执行。
// 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。
nonTx, err := validateMigrationExecutionMode(name, content)
if err != nil {
return fmt.Errorf("validate migration %s: %w", name, err)
}
if nonTx {
// *_notx.sql用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements := splitSQLStatements(content)
for i, stmt := range statements {
trimmed := strings.TrimSpace(stmt)
if trimmed == "" {
continue
}
if stripSQLLineComment(trimmed) == "" {
continue
}
if _, err := db.ExecContext(ctx, trimmed); err != nil {
return fmt.Errorf("apply migration %s (non-tx statement %d): %w", name, i+1, err)
}
}
if _, err := db.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
return fmt.Errorf("record migration %s (non-tx): %w", name, err)
}
continue
}
// 默认迁移在事务中执行,确保原子性:要么完全成功,要么完全回滚。
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin migration %s: %w", name, err)
@@ -268,6 +322,84 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
return version, version, hash, nil
}
func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
rule, ok := migrationChecksumCompatibilityRules[name]
if !ok {
return false
}
if rule.fileChecksum != fileChecksum {
return false
}
_, ok = rule.acceptedDBChecksum[dbChecksum]
return ok
}
func validateMigrationExecutionMode(name, content string) (bool, error) {
normalizedName := strings.ToLower(strings.TrimSpace(name))
upperContent := strings.ToUpper(content)
nonTx := strings.HasSuffix(normalizedName, nonTransactionalMigrationSuffix)
if !nonTx {
if strings.Contains(upperContent, "CONCURRENTLY") {
return false, errors.New("CONCURRENTLY statements must be placed in *_notx.sql migrations")
}
return false, nil
}
if strings.Contains(upperContent, "BEGIN") || strings.Contains(upperContent, "COMMIT") || strings.Contains(upperContent, "ROLLBACK") {
return false, errors.New("*_notx.sql must not contain transaction control statements (BEGIN/COMMIT/ROLLBACK)")
}
statements := splitSQLStatements(content)
for _, stmt := range statements {
normalizedStmt := strings.ToUpper(stripSQLLineComment(strings.TrimSpace(stmt)))
if normalizedStmt == "" {
continue
}
if strings.Contains(normalizedStmt, "CONCURRENTLY") {
isCreateIndex := strings.Contains(normalizedStmt, "CREATE") && strings.Contains(normalizedStmt, "INDEX")
isDropIndex := strings.Contains(normalizedStmt, "DROP") && strings.Contains(normalizedStmt, "INDEX")
if !isCreateIndex && !isDropIndex {
return false, errors.New("*_notx.sql currently only supports CREATE/DROP INDEX CONCURRENTLY statements")
}
if isCreateIndex && !strings.Contains(normalizedStmt, "IF NOT EXISTS") {
return false, errors.New("CREATE INDEX CONCURRENTLY in *_notx.sql must include IF NOT EXISTS for idempotency")
}
if isDropIndex && !strings.Contains(normalizedStmt, "IF EXISTS") {
return false, errors.New("DROP INDEX CONCURRENTLY in *_notx.sql must include IF EXISTS for idempotency")
}
continue
}
return false, errors.New("*_notx.sql must not mix non-CONCURRENTLY SQL statements")
}
return true, nil
}
func splitSQLStatements(content string) []string {
parts := strings.Split(content, ";")
out := make([]string, 0, len(parts))
for _, part := range parts {
if strings.TrimSpace(part) == "" {
continue
}
out = append(out, part)
}
return out
}
func stripSQLLineComment(s string) string {
lines := strings.Split(s, "\n")
for i, line := range lines {
if idx := strings.Index(line, "--"); idx >= 0 {
lines[i] = line[:idx]
}
}
return strings.TrimSpace(strings.Join(lines, "\n"))
}
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。

View File

@@ -0,0 +1,54 @@
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIsMigrationChecksumCompatible(t *testing.T) {
t.Run("054历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"054_drop_legacy_cache_columns.sql",
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
)
require.True(t, ok)
})
t.Run("054在未知文件checksum下不兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"054_drop_legacy_cache_columns.sql",
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
"0000000000000000000000000000000000000000000000000000000000000000",
)
require.False(t, ok)
})
t.Run("061历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"061_add_usage_log_request_type.sql",
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0",
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
)
require.True(t, ok)
})
t.Run("061第二个历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"061_add_usage_log_request_type.sql",
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3",
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
)
require.True(t, ok)
})
t.Run("非白名单迁移不兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"001_init.sql",
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
)
require.False(t, ok)
})
}

View File

@@ -0,0 +1,368 @@
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"io/fs"
"strings"
"testing"
"testing/fstest"
"time"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
func TestApplyMigrations_NilDB(t *testing.T) {
err := ApplyMigrations(context.Background(), nil)
require.Error(t, err)
require.Contains(t, err.Error(), "nil sql db")
}
func TestApplyMigrations_DelegatesToApplyMigrationsFS(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnError(errors.New("lock failed"))
err = ApplyMigrations(context.Background(), db)
require.Error(t, err)
require.Contains(t, err.Error(), "acquire migrations lock")
require.NoError(t, mock.ExpectationsWereMet())
}
func TestLatestMigrationBaseline(t *testing.T) {
t.Run("empty_fs_returns_baseline", func(t *testing.T) {
version, description, hash, err := latestMigrationBaseline(fstest.MapFS{})
require.NoError(t, err)
require.Equal(t, "baseline", version)
require.Equal(t, "baseline", description)
require.Equal(t, "", hash)
})
t.Run("uses_latest_sorted_sql_file", func(t *testing.T) {
fsys := fstest.MapFS{
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
"010_final.sql": &fstest.MapFile{
Data: []byte("CREATE TABLE t2(id int);"),
},
}
version, description, hash, err := latestMigrationBaseline(fsys)
require.NoError(t, err)
require.Equal(t, "010_final", version)
require.Equal(t, "010_final", description)
require.Len(t, hash, 64)
})
t.Run("read_file_error", func(t *testing.T) {
fsys := fstest.MapFS{
"010_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
}
_, _, _, err := latestMigrationBaseline(fsys)
require.Error(t, err)
})
}
func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
require.False(t, isMigrationChecksumCompatible("unknown.sql", "db", "file"))
var (
name string
rule migrationChecksumCompatibilityRule
)
for n, r := range migrationChecksumCompatibilityRules {
name = n
rule = r
break
}
require.NotEmpty(t, name)
require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", "file-not-match"))
require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", rule.fileChecksum))
var accepted string
for checksum := range rule.acceptedDBChecksum {
accepted = checksum
break
}
require.NotEmpty(t, accepted)
require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
}
func TestEnsureAtlasBaselineAligned(t *testing.T) {
t.Run("skip_when_no_legacy_table", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("create_atlas_and_insert_baseline_when_empty", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec("INSERT INTO atlas_schema_revisions").
WithArgs("002_next", "002_next", 1, sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
fsys := fstest.MapFS{
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
"002_next.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t2(id int);")},
}
err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("error_when_checking_legacy_table", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnError(errors.New("exists failed"))
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
require.Error(t, err)
require.Contains(t, err.Error(), "check schema_migrations")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("error_when_counting_atlas_rows", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
WillReturnError(errors.New("count failed"))
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
require.Error(t, err)
require.Contains(t, err.Error(), "count atlas_schema_revisions")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("error_when_creating_atlas_table", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
WillReturnError(errors.New("create failed"))
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
require.Error(t, err)
require.Contains(t, err.Error(), "create atlas_schema_revisions")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("error_when_inserting_baseline", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec("INSERT INTO atlas_schema_revisions").
WithArgs("001_init", "001_init", 1, sqlmock.AnyArg()).
WillReturnError(errors.New("insert failed"))
fsys := fstest.MapFS{
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
}
err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
require.Error(t, err)
require.Contains(t, err.Error(), "insert atlas baseline")
require.NoError(t, mock.ExpectationsWereMet())
})
}
func TestApplyMigrationsFS_ChecksumMismatchRejected(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("001_init.sql").
WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow("mismatched-checksum"))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.Error(t, err)
require.Contains(t, err.Error(), "checksum mismatch")
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_CheckMigrationQueryError(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("001_err.sql").
WillReturnError(errors.New("query failed"))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"001_err.sql": &fstest.MapFile{Data: []byte("SELECT 1;")},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.Error(t, err)
require.Contains(t, err.Error(), "check migration 001_err.sql")
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
alreadySQL := "CREATE TABLE t(id int);"
checksum := migrationChecksum(alreadySQL)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("001_already.sql").
WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow(checksum))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"000_empty.sql": &fstest.MapFile{Data: []byte(" \n\t ")},
"001_already.sql": &fstest.MapFile{Data: []byte(alreadySQL)},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_ReadMigrationError(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"001_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.Error(t, err)
require.Contains(t, err.Error(), "read migration 001_bad.sql")
require.NoError(t, mock.ExpectationsWereMet())
}
func TestPgAdvisoryLockAndUnlock_ErrorBranches(t *testing.T) {
t.Run("context_cancelled_while_not_locked", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer cancel()
err = pgAdvisoryLock(ctx, db)
require.Error(t, err)
require.Contains(t, err.Error(), "acquire migrations lock")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("unlock_exec_error", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnError(errors.New("unlock failed"))
err = pgAdvisoryUnlock(context.Background(), db)
require.Error(t, err)
require.Contains(t, err.Error(), "release migrations lock")
require.NoError(t, mock.ExpectationsWereMet())
})
t.Run("acquire_lock_after_retry", func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
ctx, cancel := context.WithTimeout(context.Background(), migrationsLockRetryInterval*3)
defer cancel()
start := time.Now()
err = pgAdvisoryLock(ctx, db)
require.NoError(t, err)
require.GreaterOrEqual(t, time.Since(start), migrationsLockRetryInterval)
require.NoError(t, mock.ExpectationsWereMet())
})
}
func migrationChecksum(content string) string {
sum := sha256.Sum256([]byte(strings.TrimSpace(content)))
return hex.EncodeToString(sum[:])
}

View File

@@ -0,0 +1,164 @@
package repository
import (
"context"
"database/sql"
"testing"
"testing/fstest"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
func TestValidateMigrationExecutionMode(t *testing.T) {
t.Run("事务迁移包含CONCURRENTLY会被拒绝", func(t *testing.T) {
nonTx, err := validateMigrationExecutionMode("001_add_idx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
require.False(t, nonTx)
require.Error(t, err)
})
t.Run("notx迁移要求CREATE使用IF NOT EXISTS", func(t *testing.T) {
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
require.False(t, nonTx)
require.Error(t, err)
})
t.Run("notx迁移要求DROP使用IF EXISTS", func(t *testing.T) {
nonTx, err := validateMigrationExecutionMode("001_drop_idx_notx.sql", "DROP INDEX CONCURRENTLY idx_a;")
require.False(t, nonTx)
require.Error(t, err)
})
t.Run("notx迁移禁止事务控制语句", func(t *testing.T) {
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "BEGIN; CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); COMMIT;")
require.False(t, nonTx)
require.Error(t, err)
})
t.Run("notx迁移禁止混用非CONCURRENTLY语句", func(t *testing.T) {
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); UPDATE t SET a = 1;")
require.False(t, nonTx)
require.Error(t, err)
})
t.Run("notx迁移允许幂等并发索引语句", func(t *testing.T) {
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", `
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a);
DROP INDEX CONCURRENTLY IF EXISTS idx_b;
`)
require.True(t, nonTx)
require.NoError(t, err)
})
}
func TestApplyMigrationsFS_NonTransactionalMigration(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("001_add_idx_notx.sql").
WillReturnError(sql.ErrNoRows)
mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
WithArgs("001_add_idx_notx.sql", sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"001_add_idx_notx.sql": &fstest.MapFile{
Data: []byte("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);"),
},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("001_add_multi_idx_notx.sql").
WillReturnError(sql.ErrNoRows)
mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t\\(b\\)").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
WithArgs("001_add_multi_idx_notx.sql", sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"001_add_multi_idx_notx.sql": &fstest.MapFile{
Data: []byte(`
-- first
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);
-- second
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
`),
},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("001_add_col.sql").
WillReturnError(sql.ErrNoRows)
mock.ExpectBegin()
mock.ExpectExec("ALTER TABLE t ADD COLUMN name TEXT").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
WithArgs("001_add_col.sql", sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"001_add_col.sql": &fstest.MapFile{
Data: []byte("ALTER TABLE t ADD COLUMN name TEXT;"),
},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func prepareMigrationsBootstrapExpectations(mock sqlmock.Sqlmock) {
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS schema_migrations").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("schema_migrations").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
}

View File

@@ -42,12 +42,19 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
// usage_logs: billing_type used by filters/stats
requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
// settings table should exist
var settingsRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
require.True(t, settingsRegclass.Valid, "expected settings table to exist")
// security_secrets table should exist
var securitySecretsRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.security_secrets')").Scan(&securitySecretsRegclass))
require.True(t, securitySecretsRegclass.Valid, "expected security_secrets table to exist")
// user_allowed_groups table should exist
var uagRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))

View File

@@ -4,6 +4,7 @@ import (
"context"
"net/http"
"net/url"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
@@ -21,16 +22,23 @@ type openaiOAuthService struct {
tokenURL string
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
client, err := createOpenAIReqClient(proxyURL)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err)
}
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
}
clientID = strings.TrimSpace(clientID)
if clientID == "" {
clientID = openai.ClientID
}
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("client_id", openai.ClientID)
formData.Set("client_id", clientID)
formData.Set("code", code)
formData.Set("redirect_uri", redirectURI)
formData.Set("code_verifier", codeVerifier)
@@ -56,12 +64,28 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
}
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
// 调用方应始终传入正确的 client_id为兼容旧数据未指定时默认使用 OpenAI ClientID
clientID = strings.TrimSpace(clientID)
if clientID == "" {
clientID = openai.ClientID
}
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
}
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
client, err := createOpenAIReqClient(proxyURL)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err)
}
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", openai.ClientID)
formData.Set("client_id", clientID)
formData.Set("scope", openai.RefreshScopes)
var tokenResp openai.TokenResponse
@@ -84,7 +108,7 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
return &tokenResp, nil
}
func createOpenAIReqClient(proxyURL string) *req.Client {
func createOpenAIReqClient(proxyURL string) (*req.Client, error) {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 120 * time.Second,

View File

@@ -81,7 +81,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
@@ -136,13 +136,84 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
require.Equal(s.T(), "rt2", resp.RefreshToken)
}
// TestRefreshToken_DefaultsToOpenAIClientID 验证未指定 client_id 时默认使用 OpenAI ClientID
// 且只发送一次请求(不再盲猜多个 client_id
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.NoError(s.T(), err, "RefreshToken")
require.Equal(s.T(), "at", resp.AccessToken)
// 只发送了一次请求,使用默认的 OpenAI ClientID
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
}
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.SoraClientID {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
return
}
w.WriteHeader(http.StatusBadRequest)
}))
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
require.NoError(s.T(), err, "RefreshTokenWithClientID")
require.Equal(s.T(), "at-sora", resp.AccessToken)
require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
const customClientID = "custom-client-id"
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID != customClientID {
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
require.NoError(s.T(), err, "RefreshTokenWithClientID")
require.Equal(s.T(), "at-custom", resp.AccessToken)
require.Equal(s.T(), "rt-custom", resp.RefreshToken)
require.Equal(s.T(), []string{customClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "bad")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 400")
require.ErrorContains(s.T(), err, "bad")
@@ -152,7 +223,7 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "request failed")
}
@@ -169,7 +240,7 @@ func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
done := make(chan error, 1)
go func() {
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
done <- err
}()
@@ -195,7 +266,30 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
wantClientID := openai.SoraClientID
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
if got := r.PostForm.Get("client_id"); got != wantClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", wantClientID)
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
@@ -213,7 +307,7 @@ func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
}))
s.svc.tokenURL = s.srv.URL + "?x=1"
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case <-s.received:
@@ -229,7 +323,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
_, _ = io.WriteString(w, "not-valid-json")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err, "expected error for invalid JSON response")
}

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
@@ -55,6 +56,10 @@ INSERT INTO ops_error_logs (
upstream_error_message,
upstream_error_detail,
upstream_errors,
auth_latency_ms,
routing_latency_ms,
upstream_latency_ms,
response_latency_ms,
time_to_first_token_ms,
request_body,
request_body_truncated,
@@ -64,7 +69,7 @@ INSERT INTO ops_error_logs (
retry_count,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
) RETURNING id`
var id int64
@@ -97,6 +102,10 @@ INSERT INTO ops_error_logs (
opsNullString(input.UpstreamErrorMessage),
opsNullString(input.UpstreamErrorDetail),
opsNullString(input.UpstreamErrorsJSON),
opsNullInt64(input.AuthLatencyMs),
opsNullInt64(input.RoutingLatencyMs),
opsNullInt64(input.UpstreamLatencyMs),
opsNullInt64(input.ResponseLatencyMs),
opsNullInt64(input.TimeToFirstTokenMs),
opsNullString(input.RequestBodyJSON),
input.RequestBodyTruncated,
@@ -930,6 +939,243 @@ WHERE id = $1`
return err
}
func (r *opsRepository) BatchInsertSystemLogs(ctx context.Context, inputs []*service.OpsInsertSystemLogInput) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if len(inputs) == 0 {
return 0, nil
}
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return 0, err
}
stmt, err := tx.PrepareContext(ctx, pq.CopyIn(
"ops_system_logs",
"created_at",
"level",
"component",
"message",
"request_id",
"client_request_id",
"user_id",
"account_id",
"platform",
"model",
"extra",
))
if err != nil {
_ = tx.Rollback()
return 0, err
}
var inserted int64
for _, input := range inputs {
if input == nil {
continue
}
createdAt := input.CreatedAt
if createdAt.IsZero() {
createdAt = time.Now().UTC()
}
component := strings.TrimSpace(input.Component)
level := strings.ToLower(strings.TrimSpace(input.Level))
message := strings.TrimSpace(input.Message)
if level == "" || message == "" {
continue
}
if component == "" {
component = "app"
}
extra := strings.TrimSpace(input.ExtraJSON)
if extra == "" {
extra = "{}"
}
if _, err := stmt.ExecContext(
ctx,
createdAt.UTC(),
level,
component,
message,
opsNullString(input.RequestID),
opsNullString(input.ClientRequestID),
opsNullInt64(input.UserID),
opsNullInt64(input.AccountID),
opsNullString(input.Platform),
opsNullString(input.Model),
extra,
); err != nil {
_ = stmt.Close()
_ = tx.Rollback()
return inserted, err
}
inserted++
}
if _, err := stmt.ExecContext(ctx); err != nil {
_ = stmt.Close()
_ = tx.Rollback()
return inserted, err
}
if err := stmt.Close(); err != nil {
_ = tx.Rollback()
return inserted, err
}
if err := tx.Commit(); err != nil {
return inserted, err
}
return inserted, nil
}
func (r *opsRepository) ListSystemLogs(ctx context.Context, filter *service.OpsSystemLogFilter) (*service.OpsSystemLogList, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
filter = &service.OpsSystemLogFilter{}
}
page := filter.Page
if page <= 0 {
page = 1
}
pageSize := filter.PageSize
if pageSize <= 0 {
pageSize = 50
}
if pageSize > 200 {
pageSize = 200
}
where, args, _ := buildOpsSystemLogsWhere(filter)
countSQL := "SELECT COUNT(*) FROM ops_system_logs l " + where
var total int
if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
return nil, err
}
offset := (page - 1) * pageSize
argsWithLimit := append(args, pageSize, offset)
query := `
SELECT
l.id,
l.created_at,
l.level,
COALESCE(l.component, ''),
COALESCE(l.message, ''),
COALESCE(l.request_id, ''),
COALESCE(l.client_request_id, ''),
l.user_id,
l.account_id,
COALESCE(l.platform, ''),
COALESCE(l.model, ''),
COALESCE(l.extra::text, '{}')
FROM ops_system_logs l
` + where + `
ORDER BY l.created_at DESC, l.id DESC
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
rows, err := r.db.QueryContext(ctx, query, argsWithLimit...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
logs := make([]*service.OpsSystemLog, 0, pageSize)
for rows.Next() {
item := &service.OpsSystemLog{}
var userID sql.NullInt64
var accountID sql.NullInt64
var extraRaw string
if err := rows.Scan(
&item.ID,
&item.CreatedAt,
&item.Level,
&item.Component,
&item.Message,
&item.RequestID,
&item.ClientRequestID,
&userID,
&accountID,
&item.Platform,
&item.Model,
&extraRaw,
); err != nil {
return nil, err
}
if userID.Valid {
v := userID.Int64
item.UserID = &v
}
if accountID.Valid {
v := accountID.Int64
item.AccountID = &v
}
extraRaw = strings.TrimSpace(extraRaw)
if extraRaw != "" && extraRaw != "null" && extraRaw != "{}" {
extra := make(map[string]any)
if err := json.Unmarshal([]byte(extraRaw), &extra); err == nil {
item.Extra = extra
}
}
logs = append(logs, item)
}
if err := rows.Err(); err != nil {
return nil, err
}
return &service.OpsSystemLogList{
Logs: logs,
Total: total,
Page: page,
PageSize: pageSize,
}, nil
}
func (r *opsRepository) DeleteSystemLogs(ctx context.Context, filter *service.OpsSystemLogCleanupFilter) (int64, error) {
if r == nil || r.db == nil {
return 0, fmt.Errorf("nil ops repository")
}
if filter == nil {
filter = &service.OpsSystemLogCleanupFilter{}
}
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter)
if !hasConstraint {
return 0, fmt.Errorf("cleanup requires at least one filter condition")
}
query := "DELETE FROM ops_system_logs l " + where
res, err := r.db.ExecContext(ctx, query, args...)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (r *opsRepository) InsertSystemLogCleanupAudit(ctx context.Context, input *service.OpsSystemLogCleanupAudit) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if input == nil {
return fmt.Errorf("nil input")
}
createdAt := input.CreatedAt
if createdAt.IsZero() {
createdAt = time.Now().UTC()
}
_, err := r.db.ExecContext(ctx, `
INSERT INTO ops_system_log_cleanup_audits (
created_at,
operator_id,
conditions,
deleted_rows
) VALUES ($1,$2,$3,$4)
`, createdAt.UTC(), input.OperatorID, input.Conditions, input.DeletedRows)
return err
}
func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
clauses := make([]string, 0, 12)
args := make([]any, 0, 12)
@@ -948,7 +1194,7 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
if phaseFilter != "upstream" {
clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
clauses = append(clauses, "COALESCE(e.status_code, 0) >= 400")
}
if filter.StartTime != nil && !filter.StartTime.IsZero() {
@@ -962,33 +1208,33 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
if p := strings.TrimSpace(filter.Platform); p != "" {
args = append(args, p)
clauses = append(clauses, "platform = $"+itoa(len(args)))
clauses = append(clauses, "e.platform = $"+itoa(len(args)))
}
if filter.GroupID != nil && *filter.GroupID > 0 {
args = append(args, *filter.GroupID)
clauses = append(clauses, "group_id = $"+itoa(len(args)))
clauses = append(clauses, "e.group_id = $"+itoa(len(args)))
}
if filter.AccountID != nil && *filter.AccountID > 0 {
args = append(args, *filter.AccountID)
clauses = append(clauses, "account_id = $"+itoa(len(args)))
clauses = append(clauses, "e.account_id = $"+itoa(len(args)))
}
if phase := phaseFilter; phase != "" {
args = append(args, phase)
clauses = append(clauses, "error_phase = $"+itoa(len(args)))
clauses = append(clauses, "e.error_phase = $"+itoa(len(args)))
}
if filter != nil {
if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
args = append(args, owner)
clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args)))
clauses = append(clauses, "LOWER(COALESCE(e.error_owner,'')) = $"+itoa(len(args)))
}
if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
args = append(args, source)
clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args)))
clauses = append(clauses, "LOWER(COALESCE(e.error_source,'')) = $"+itoa(len(args)))
}
}
if resolvedFilter != nil {
args = append(args, *resolvedFilter)
clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args)))
clauses = append(clauses, "COALESCE(e.resolved,false) = $"+itoa(len(args)))
}
// View filter: errors vs excluded vs all.
@@ -1000,51 +1246,140 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
switch view {
case "", "errors":
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
case "excluded":
clauses = append(clauses, "COALESCE(is_business_limited,false) = true")
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = true")
case "all":
// no-op
default:
// treat unknown as default 'errors'
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
}
if len(filter.StatusCodes) > 0 {
args = append(args, pq.Array(filter.StatusCodes))
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
clauses = append(clauses, "COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+")")
} else if filter.StatusCodesOther {
// "Other" means: status codes not in the common list.
known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
args = append(args, pq.Array(known))
clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))")
clauses = append(clauses, "NOT (COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+"))")
}
// Exact correlation keys (preferred for request↔upstream linkage).
if rid := strings.TrimSpace(filter.RequestID); rid != "" {
args = append(args, rid)
clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args)))
clauses = append(clauses, "COALESCE(e.request_id,'') = $"+itoa(len(args)))
}
if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
args = append(args, crid)
clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args)))
clauses = append(clauses, "COALESCE(e.client_request_id,'') = $"+itoa(len(args)))
}
if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%"
args = append(args, like)
n := itoa(len(args))
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
clauses = append(clauses, "(e.request_id ILIKE $"+n+" OR e.client_request_id ILIKE $"+n+" OR e.error_message ILIKE $"+n+")")
}
if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
like := "%" + userQuery + "%"
args = append(args, like)
n := itoa(len(args))
clauses = append(clauses, "u.email ILIKE $"+n)
clauses = append(clauses, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $"+n+")")
}
return "WHERE " + strings.Join(clauses, " AND "), args
}
func buildOpsSystemLogsWhere(filter *service.OpsSystemLogFilter) (string, []any, bool) {
clauses := make([]string, 0, 10)
args := make([]any, 0, 10)
clauses = append(clauses, "1=1")
hasConstraint := false
if filter != nil && filter.StartTime != nil && !filter.StartTime.IsZero() {
args = append(args, filter.StartTime.UTC())
clauses = append(clauses, "l.created_at >= $"+itoa(len(args)))
hasConstraint = true
}
if filter != nil && filter.EndTime != nil && !filter.EndTime.IsZero() {
args = append(args, filter.EndTime.UTC())
clauses = append(clauses, "l.created_at < $"+itoa(len(args)))
hasConstraint = true
}
if filter != nil {
if v := strings.ToLower(strings.TrimSpace(filter.Level)); v != "" {
args = append(args, v)
clauses = append(clauses, "LOWER(COALESCE(l.level,'')) = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.Component); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.component,'') = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.RequestID); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.request_id,'') = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.ClientRequestID); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.client_request_id,'') = $"+itoa(len(args)))
hasConstraint = true
}
if filter.UserID != nil && *filter.UserID > 0 {
args = append(args, *filter.UserID)
clauses = append(clauses, "l.user_id = $"+itoa(len(args)))
hasConstraint = true
}
if filter.AccountID != nil && *filter.AccountID > 0 {
args = append(args, *filter.AccountID)
clauses = append(clauses, "l.account_id = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.Platform); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.platform,'') = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.Model); v != "" {
args = append(args, v)
clauses = append(clauses, "COALESCE(l.model,'') = $"+itoa(len(args)))
hasConstraint = true
}
if v := strings.TrimSpace(filter.Query); v != "" {
like := "%" + v + "%"
args = append(args, like)
n := itoa(len(args))
clauses = append(clauses, "(l.message ILIKE $"+n+" OR COALESCE(l.request_id,'') ILIKE $"+n+" OR COALESCE(l.client_request_id,'') ILIKE $"+n+" OR COALESCE(l.extra::text,'') ILIKE $"+n+")")
hasConstraint = true
}
}
return "WHERE " + strings.Join(clauses, " AND "), args, hasConstraint
}
func buildOpsSystemLogsCleanupWhere(filter *service.OpsSystemLogCleanupFilter) (string, []any, bool) {
if filter == nil {
filter = &service.OpsSystemLogCleanupFilter{}
}
listFilter := &service.OpsSystemLogFilter{
StartTime: filter.StartTime,
EndTime: filter.EndTime,
Level: filter.Level,
Component: filter.Component,
RequestID: filter.RequestID,
ClientRequestID: filter.ClientRequestID,
UserID: filter.UserID,
AccountID: filter.AccountID,
Platform: filter.Platform,
Model: filter.Model,
Query: filter.Query,
}
return buildOpsSystemLogsWhere(listFilter)
}
// Helpers for nullable args
func opsNullString(v any) any {
switch s := v.(type) {

View File

@@ -12,6 +12,11 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
)
const (
opsRawLatencyQueryTimeout = 2 * time.Second
opsRawPeakQueryTimeout = 1500 * time.Millisecond
)
func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
@@ -45,15 +50,24 @@ func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *servic
func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) {
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
degraded := false
successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end)
if err != nil {
return nil, err
}
duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end)
latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout)
duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end)
cancelLatency()
if err != nil {
return nil, err
if isQueryTimeoutErr(err) {
degraded = true
duration = service.OpsPercentiles{}
ttft = service.OpsPercentiles{}
} else {
return nil, err
}
}
errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end)
@@ -75,20 +89,40 @@ func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *ser
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
if err != nil {
return nil, err
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end)
peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout)
qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end)
cancelPeak()
if err != nil {
return nil, err
}
tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end)
if err != nil {
return nil, err
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
if degraded {
if qpsCurrent <= 0 {
qpsCurrent = qpsAvg
}
if tpsCurrent <= 0 {
tpsCurrent = tpsAvg
}
if qpsPeak <= 0 {
qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg))
}
if tpsPeak <= 0 {
tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg))
}
}
return &service.OpsDashboardOverview{
StartTime: start,
@@ -230,26 +264,45 @@ func (r *opsRepository) getDashboardOverviewPreaggregated(ctx context.Context, f
sla := safeDivideFloat64(float64(successCount), float64(requestCountSLA))
errorRate := safeDivideFloat64(float64(errorCountSLA), float64(requestCountSLA))
upstreamErrorRate := safeDivideFloat64(float64(upstreamExcl), float64(requestCountSLA))
degraded := false
// Keep "current" rates as raw, to preserve realtime semantics.
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
if err != nil {
return nil, err
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
// NOTE: peak still uses raw logs (minute granularity). This is typically cheaper than percentile_cont
// and keeps semantics consistent across modes.
qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end)
peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout)
qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end)
cancelPeak()
if err != nil {
return nil, err
}
tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end)
if err != nil {
return nil, err
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
if degraded {
if qpsCurrent <= 0 {
qpsCurrent = qpsAvg
}
if tpsCurrent <= 0 {
tpsCurrent = tpsAvg
}
if qpsPeak <= 0 {
qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg))
}
if tpsPeak <= 0 {
tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg))
}
}
return &service.OpsDashboardOverview{
StartTime: start,
@@ -577,9 +630,16 @@ func (r *opsRepository) queryRawPartial(ctx context.Context, filter *service.Ops
return nil, err
}
duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end)
latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout)
duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end)
cancelLatency()
if err != nil {
return nil, err
if isQueryTimeoutErr(err) {
duration = service.OpsPercentiles{}
ttft = service.OpsPercentiles{}
} else {
return nil, err
}
}
errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end)
@@ -735,68 +795,56 @@ FROM usage_logs ul
}
func (r *opsRepository) queryUsageLatency(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (duration service.OpsPercentiles, ttft service.OpsPercentiles, err error) {
{
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
q := `
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
q := `
SELECT
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) AS p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) AS p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) AS p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) AS p99,
AVG(duration_ms) AS avg_ms,
MAX(duration_ms) AS max_ms
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99,
AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg,
MAX(duration_ms) AS duration_max,
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99,
AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg,
MAX(first_token_ms) AS ttft_max
FROM usage_logs ul
` + join + `
` + where + `
AND duration_ms IS NOT NULL`
` + where
var p50, p90, p95, p99 sql.NullFloat64
var avg sql.NullFloat64
var max sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil {
return service.OpsPercentiles{}, service.OpsPercentiles{}, err
}
duration.P50 = floatToIntPtr(p50)
duration.P90 = floatToIntPtr(p90)
duration.P95 = floatToIntPtr(p95)
duration.P99 = floatToIntPtr(p99)
duration.Avg = floatToIntPtr(avg)
if max.Valid {
v := int(max.Int64)
duration.Max = &v
}
var dP50, dP90, dP95, dP99 sql.NullFloat64
var dAvg sql.NullFloat64
var dMax sql.NullInt64
var tP50, tP90, tP95, tP99 sql.NullFloat64
var tAvg sql.NullFloat64
var tMax sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(
&dP50, &dP90, &dP95, &dP99, &dAvg, &dMax,
&tP50, &tP90, &tP95, &tP99, &tAvg, &tMax,
); err != nil {
return service.OpsPercentiles{}, service.OpsPercentiles{}, err
}
{
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
q := `
SELECT
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) AS p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) AS p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) AS p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) AS p99,
AVG(first_token_ms) AS avg_ms,
MAX(first_token_ms) AS max_ms
FROM usage_logs ul
` + join + `
` + where + `
AND first_token_ms IS NOT NULL`
duration.P50 = floatToIntPtr(dP50)
duration.P90 = floatToIntPtr(dP90)
duration.P95 = floatToIntPtr(dP95)
duration.P99 = floatToIntPtr(dP99)
duration.Avg = floatToIntPtr(dAvg)
if dMax.Valid {
v := int(dMax.Int64)
duration.Max = &v
}
var p50, p90, p95, p99 sql.NullFloat64
var avg sql.NullFloat64
var max sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil {
return service.OpsPercentiles{}, service.OpsPercentiles{}, err
}
ttft.P50 = floatToIntPtr(p50)
ttft.P90 = floatToIntPtr(p90)
ttft.P95 = floatToIntPtr(p95)
ttft.P99 = floatToIntPtr(p99)
ttft.Avg = floatToIntPtr(avg)
if max.Valid {
v := int(max.Int64)
ttft.Max = &v
}
ttft.P50 = floatToIntPtr(tP50)
ttft.P90 = floatToIntPtr(tP90)
ttft.P95 = floatToIntPtr(tP95)
ttft.P99 = floatToIntPtr(tP99)
ttft.Avg = floatToIntPtr(tAvg)
if tMax.Valid {
v := int(tMax.Int64)
ttft.Max = &v
}
return duration, ttft, nil
@@ -854,20 +902,23 @@ func (r *opsRepository) queryCurrentRates(ctx context.Context, filter *service.O
return qpsCurrent, tpsCurrent, nil
}
func (r *opsRepository) queryPeakQPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) {
func (r *opsRepository) queryPeakRates(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (qpsPeak float64, tpsPeak float64, err error) {
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
q := `
WITH usage_buckets AS (
SELECT date_trunc('minute', ul.created_at) AS bucket, COUNT(*) AS cnt
SELECT
date_trunc('minute', ul.created_at) AS bucket,
COUNT(*) AS req_cnt,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_cnt
FROM usage_logs ul
` + usageJoin + `
` + usageWhere + `
GROUP BY 1
),
error_buckets AS (
SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS cnt
SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS err_cnt
FROM ops_error_logs
` + errorWhere + `
AND COALESCE(status_code, 0) >= 400
@@ -875,47 +926,33 @@ error_buckets AS (
),
combined AS (
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
COALESCE(u.cnt, 0) + COALESCE(e.cnt, 0) AS total
COALESCE(u.req_cnt, 0) + COALESCE(e.err_cnt, 0) AS total_req,
COALESCE(u.token_cnt, 0) AS total_tokens
FROM usage_buckets u
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
)
SELECT COALESCE(MAX(total), 0) FROM combined`
SELECT
COALESCE(MAX(total_req), 0) AS max_req_per_min,
COALESCE(MAX(total_tokens), 0) AS max_tokens_per_min
FROM combined`
args := append(usageArgs, errorArgs...)
var maxPerMinute sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil {
return 0, err
var maxReqPerMinute, maxTokensPerMinute sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxReqPerMinute, &maxTokensPerMinute); err != nil {
return 0, 0, err
}
if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 {
return 0, nil
if maxReqPerMinute.Valid && maxReqPerMinute.Int64 > 0 {
qpsPeak = roundTo1DP(float64(maxReqPerMinute.Int64) / 60.0)
}
return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil
if maxTokensPerMinute.Valid && maxTokensPerMinute.Int64 > 0 {
tpsPeak = roundTo1DP(float64(maxTokensPerMinute.Int64) / 60.0)
}
return qpsPeak, tpsPeak, nil
}
func (r *opsRepository) queryPeakTPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) {
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
q := `
SELECT COALESCE(MAX(tokens_per_min), 0)
FROM (
SELECT
date_trunc('minute', ul.created_at) AS bucket,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS tokens_per_min
FROM usage_logs ul
` + join + `
` + where + `
GROUP BY 1
) t`
var maxPerMinute sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil {
return 0, err
}
if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 {
return 0, nil
}
return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil
func isQueryTimeoutErr(err error) bool {
return errors.Is(err, context.DeadlineExceeded)
}
func buildUsageWhere(filter *service.OpsDashboardFilter, start, end time.Time, startIndex int) (join string, where string, args []any, nextIndex int) {

View File

@@ -0,0 +1,22 @@
package repository
import (
"context"
"fmt"
"testing"
)
func TestIsQueryTimeoutErr(t *testing.T) {
if !isQueryTimeoutErr(context.DeadlineExceeded) {
t.Fatalf("context.DeadlineExceeded should be treated as query timeout")
}
if !isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.DeadlineExceeded)) {
t.Fatalf("wrapped context.DeadlineExceeded should be treated as query timeout")
}
if isQueryTimeoutErr(context.Canceled) {
t.Fatalf("context.Canceled should not be treated as query timeout")
}
if isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.Canceled)) {
t.Fatalf("wrapped context.Canceled should not be treated as query timeout")
}
}

View File

@@ -0,0 +1,48 @@
package repository
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns(t *testing.T) {
filter := &service.OpsErrorLogFilter{
Query: "ACCESS_DENIED",
}
where, args := buildOpsErrorLogsWhere(filter)
if where == "" {
t.Fatalf("where should not be empty")
}
if len(args) != 1 {
t.Fatalf("args len = %d, want 1", len(args))
}
if !strings.Contains(where, "e.request_id ILIKE $") {
t.Fatalf("where should include qualified request_id condition: %s", where)
}
if !strings.Contains(where, "e.client_request_id ILIKE $") {
t.Fatalf("where should include qualified client_request_id condition: %s", where)
}
if !strings.Contains(where, "e.error_message ILIKE $") {
t.Fatalf("where should include qualified error_message condition: %s", where)
}
}
func TestBuildOpsErrorLogsWhere_UserQueryUsesExistsSubquery(t *testing.T) {
filter := &service.OpsErrorLogFilter{
UserQuery: "admin@",
}
where, args := buildOpsErrorLogsWhere(filter)
if where == "" {
t.Fatalf("where should not be empty")
}
if len(args) != 1 {
t.Fatalf("args len = %d, want 1", len(args))
}
if !strings.Contains(where, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $") {
t.Fatalf("where should include EXISTS user email condition: %s", where)
}
}

View File

@@ -35,12 +35,12 @@ func latencyHistogramRangeCaseExpr(column string) string {
if b.upperMs <= 0 {
continue
}
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label))
fmt.Fprintf(&sb, "\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label)
}
// Default bucket.
last := latencyHistogramBuckets[len(latencyHistogramBuckets)-1]
_, _ = sb.WriteString(fmt.Sprintf("\tELSE '%s'\n", last.label))
fmt.Fprintf(&sb, "\tELSE '%s'\n", last.label)
_, _ = sb.WriteString("END")
return sb.String()
}
@@ -54,11 +54,11 @@ func latencyHistogramRangeOrderCaseExpr(column string) string {
if b.upperMs <= 0 {
continue
}
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN %d\n", column, b.upperMs, order))
fmt.Fprintf(&sb, "\tWHEN %s < %d THEN %d\n", column, b.upperMs, order)
order++
}
_, _ = sb.WriteString(fmt.Sprintf("\tELSE %d\n", order))
fmt.Fprintf(&sb, "\tELSE %d\n", order)
_, _ = sb.WriteString("END")
return sb.String()
}

View File

@@ -0,0 +1,145 @@
package repository
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) GetOpenAITokenStats(ctx context.Context, filter *service.OpsOpenAITokenStatsFilter) (*service.OpsOpenAITokenStatsResponse, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
// 允许 start_time == end_time结果为空与 service 层校验口径保持一致。
if filter.StartTime.After(filter.EndTime) {
return nil, fmt.Errorf("start_time must be <= end_time")
}
dashboardFilter := &service.OpsDashboardFilter{
StartTime: filter.StartTime.UTC(),
EndTime: filter.EndTime.UTC(),
Platform: strings.TrimSpace(strings.ToLower(filter.Platform)),
GroupID: filter.GroupID,
}
join, where, baseArgs, next := buildUsageWhere(dashboardFilter, dashboardFilter.StartTime, dashboardFilter.EndTime, 1)
where += " AND ul.model LIKE 'gpt%'"
baseCTE := `
WITH stats AS (
SELECT
ul.model AS model,
COUNT(*)::bigint AS request_count,
ROUND(
AVG(
CASE
WHEN ul.duration_ms > 0 AND ul.output_tokens > 0
THEN ul.output_tokens * 1000.0 / ul.duration_ms
END
)::numeric,
2
)::float8 AS avg_tokens_per_sec,
ROUND(AVG(ul.first_token_ms)::numeric, 2)::float8 AS avg_first_token_ms,
COALESCE(SUM(ul.output_tokens), 0)::bigint AS total_output_tokens,
COALESCE(ROUND(AVG(ul.duration_ms)::numeric, 0), 0)::bigint AS avg_duration_ms,
COUNT(CASE WHEN ul.first_token_ms IS NOT NULL THEN 1 END)::bigint AS requests_with_first_token
FROM usage_logs ul
` + join + `
` + where + `
GROUP BY ul.model
)
`
countSQL := baseCTE + `SELECT COUNT(*) FROM stats`
var total int64
if err := r.db.QueryRowContext(ctx, countSQL, baseArgs...).Scan(&total); err != nil {
return nil, err
}
querySQL := baseCTE + `
SELECT
model,
request_count,
avg_tokens_per_sec,
avg_first_token_ms,
total_output_tokens,
avg_duration_ms,
requests_with_first_token
FROM stats
ORDER BY request_count DESC, model ASC`
args := make([]any, 0, len(baseArgs)+2)
args = append(args, baseArgs...)
if filter.IsTopNMode() {
querySQL += fmt.Sprintf("\nLIMIT $%d", next)
args = append(args, filter.TopN)
} else {
offset := (filter.Page - 1) * filter.PageSize
querySQL += fmt.Sprintf("\nLIMIT $%d OFFSET $%d", next, next+1)
args = append(args, filter.PageSize, offset)
}
rows, err := r.db.QueryContext(ctx, querySQL, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
items := make([]*service.OpsOpenAITokenStatsItem, 0, 32)
for rows.Next() {
item := &service.OpsOpenAITokenStatsItem{}
var avgTPS sql.NullFloat64
var avgFirstToken sql.NullFloat64
if err := rows.Scan(
&item.Model,
&item.RequestCount,
&avgTPS,
&avgFirstToken,
&item.TotalOutputTokens,
&item.AvgDurationMs,
&item.RequestsWithFirstToken,
); err != nil {
return nil, err
}
if avgTPS.Valid {
v := avgTPS.Float64
item.AvgTokensPerSec = &v
}
if avgFirstToken.Valid {
v := avgFirstToken.Float64
item.AvgFirstTokenMs = &v
}
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, err
}
resp := &service.OpsOpenAITokenStatsResponse{
TimeRange: strings.TrimSpace(filter.TimeRange),
StartTime: dashboardFilter.StartTime,
EndTime: dashboardFilter.EndTime,
Platform: dashboardFilter.Platform,
GroupID: dashboardFilter.GroupID,
Items: items,
Total: total,
}
if filter.IsTopNMode() {
topN := filter.TopN
resp.TopN = &topN
} else {
resp.Page = filter.Page
resp.PageSize = filter.PageSize
}
return resp, nil
}

View File

@@ -0,0 +1,156 @@
package repository
import (
"context"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestOpsRepositoryGetOpenAITokenStats_PaginationMode(t *testing.T) {
db, mock := newSQLMock(t)
repo := &opsRepository{db: db}
start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
groupID := int64(9)
filter := &service.OpsOpenAITokenStatsFilter{
TimeRange: "1d",
StartTime: start,
EndTime: end,
Platform: " OpenAI ",
GroupID: &groupID,
Page: 2,
PageSize: 10,
}
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`).
WithArgs(start, end, groupID, "openai").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(3)))
rows := sqlmock.NewRows([]string{
"model",
"request_count",
"avg_tokens_per_sec",
"avg_first_token_ms",
"total_output_tokens",
"avg_duration_ms",
"requests_with_first_token",
}).
AddRow("gpt-4o-mini", int64(20), 21.56, 120.34, int64(3000), int64(850), int64(18)).
AddRow("gpt-4.1", int64(20), 10.2, 240.0, int64(2500), int64(900), int64(20))
mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$5 OFFSET \$6`).
WithArgs(start, end, groupID, "openai", 10, 10).
WillReturnRows(rows)
resp, err := repo.GetOpenAITokenStats(context.Background(), filter)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, int64(3), resp.Total)
require.Equal(t, 2, resp.Page)
require.Equal(t, 10, resp.PageSize)
require.Nil(t, resp.TopN)
require.Equal(t, "openai", resp.Platform)
require.NotNil(t, resp.GroupID)
require.Equal(t, groupID, *resp.GroupID)
require.Len(t, resp.Items, 2)
require.Equal(t, "gpt-4o-mini", resp.Items[0].Model)
require.NotNil(t, resp.Items[0].AvgTokensPerSec)
require.InDelta(t, 21.56, *resp.Items[0].AvgTokensPerSec, 0.0001)
require.NotNil(t, resp.Items[0].AvgFirstTokenMs)
require.InDelta(t, 120.34, *resp.Items[0].AvgFirstTokenMs, 0.0001)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestOpsRepositoryGetOpenAITokenStats_TopNMode(t *testing.T) {
db, mock := newSQLMock(t)
repo := &opsRepository{db: db}
start := time.Date(2026, 1, 1, 10, 0, 0, 0, time.UTC)
end := start.Add(time.Hour)
filter := &service.OpsOpenAITokenStatsFilter{
TimeRange: "1h",
StartTime: start,
EndTime: end,
TopN: 5,
}
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`).
WithArgs(start, end).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
rows := sqlmock.NewRows([]string{
"model",
"request_count",
"avg_tokens_per_sec",
"avg_first_token_ms",
"total_output_tokens",
"avg_duration_ms",
"requests_with_first_token",
}).
AddRow("gpt-4o", int64(5), nil, nil, int64(0), int64(0), int64(0))
mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3`).
WithArgs(start, end, 5).
WillReturnRows(rows)
resp, err := repo.GetOpenAITokenStats(context.Background(), filter)
require.NoError(t, err)
require.NotNil(t, resp)
require.NotNil(t, resp.TopN)
require.Equal(t, 5, *resp.TopN)
require.Equal(t, 0, resp.Page)
require.Equal(t, 0, resp.PageSize)
require.Len(t, resp.Items, 1)
require.Nil(t, resp.Items[0].AvgTokensPerSec)
require.Nil(t, resp.Items[0].AvgFirstTokenMs)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestOpsRepositoryGetOpenAITokenStats_EmptyResult(t *testing.T) {
db, mock := newSQLMock(t)
repo := &opsRepository{db: db}
start := time.Date(2026, 1, 2, 0, 0, 0, 0, time.UTC)
end := start.Add(30 * time.Minute)
filter := &service.OpsOpenAITokenStatsFilter{
TimeRange: "30m",
StartTime: start,
EndTime: end,
Page: 1,
PageSize: 20,
}
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`).
WithArgs(start, end).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3 OFFSET \$4`).
WithArgs(start, end, 20, 0).
WillReturnRows(sqlmock.NewRows([]string{
"model",
"request_count",
"avg_tokens_per_sec",
"avg_first_token_ms",
"total_output_tokens",
"avg_duration_ms",
"requests_with_first_token",
}))
resp, err := repo.GetOpenAITokenStats(context.Background(), filter)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, int64(0), resp.Total)
require.Len(t, resp.Items, 0)
require.Equal(t, 1, resp.Page)
require.Equal(t, 20, resp.PageSize)
require.NoError(t, mock.ExpectationsWereMet())
}

View File

@@ -0,0 +1,86 @@
package repository
import (
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestBuildOpsSystemLogsWhere_WithClientRequestIDAndUserID(t *testing.T) {
start := time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
end := time.Date(2026, 2, 2, 0, 0, 0, 0, time.UTC)
userID := int64(12)
accountID := int64(34)
filter := &service.OpsSystemLogFilter{
StartTime: &start,
EndTime: &end,
Level: "warn",
Component: "http.access",
RequestID: "req-1",
ClientRequestID: "creq-1",
UserID: &userID,
AccountID: &accountID,
Platform: "openai",
Model: "gpt-5",
Query: "timeout",
}
where, args, hasConstraint := buildOpsSystemLogsWhere(filter)
if !hasConstraint {
t.Fatalf("expected hasConstraint=true")
}
if where == "" {
t.Fatalf("where should not be empty")
}
if len(args) != 11 {
t.Fatalf("args len = %d, want 11", len(args))
}
if !contains(where, "COALESCE(l.client_request_id,'') = $") {
t.Fatalf("where should include client_request_id condition: %s", where)
}
if !contains(where, "l.user_id = $") {
t.Fatalf("where should include user_id condition: %s", where)
}
}
func TestBuildOpsSystemLogsCleanupWhere_RequireConstraint(t *testing.T) {
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(&service.OpsSystemLogCleanupFilter{})
if hasConstraint {
t.Fatalf("expected hasConstraint=false")
}
if where == "" {
t.Fatalf("where should not be empty")
}
if len(args) != 0 {
t.Fatalf("args len = %d, want 0", len(args))
}
}
func TestBuildOpsSystemLogsCleanupWhere_WithClientRequestIDAndUserID(t *testing.T) {
userID := int64(9)
filter := &service.OpsSystemLogCleanupFilter{
ClientRequestID: "creq-9",
UserID: &userID,
}
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter)
if !hasConstraint {
t.Fatalf("expected hasConstraint=true")
}
if len(args) != 2 {
t.Fatalf("args len = %d, want 2", len(args))
}
if !contains(where, "COALESCE(l.client_request_id,'') = $") {
t.Fatalf("where should include client_request_id condition: %s", where)
}
if !contains(where, "l.user_id = $") {
t.Fatalf("where should include user_id condition: %s", where)
}
}
func contains(s string, sub string) bool {
return strings.Contains(s, sub)
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
@@ -16,14 +17,37 @@ type pricingRemoteClient struct {
httpClient *http.Client
}
// pricingRemoteClientError 代理初始化失败时的错误占位客户端
// 所有请求直接返回初始化错误,禁止回退到直连
type pricingRemoteClientError struct {
err error
}
func (c *pricingRemoteClientError) FetchPricingJSON(_ context.Context, _ string) ([]byte, error) {
return nil, c.err
}
func (c *pricingRemoteClientError) FetchHashText(_ context.Context, _ string) (string, error) {
return "", c.err
}
// NewPricingRemoteClient 创建定价数据远程客户端
// proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议
func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient {
// 代理配置失败时行为由 allowDirectOnProxyError 控制:
// - false默认返回错误占位客户端禁止回退到直连
// - true回退到直连仅限管理员显式开启
func NewPricingRemoteClient(proxyURL string, allowDirectOnProxyError bool) service.PricingRemoteClient {
// 安全说明httpclient.GetClient 的错误链url.Parse / proxyutil不含明文代理凭据
// 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
slog.Warn("proxy client init failed, all requests will fail", "service", "pricing", "error", err)
return &pricingRemoteClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
}
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &pricingRemoteClient{

View File

@@ -19,7 +19,7 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
client, ok := NewPricingRemoteClient("").(*pricingRemoteClient)
client, ok := NewPricingRemoteClient("", false).(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}
@@ -140,6 +140,22 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
require.Error(s.T(), err)
}
func TestNewPricingRemoteClient_InvalidProxy_NoFallback(t *testing.T) {
client := NewPricingRemoteClient("://bad", false)
_, ok := client.(*pricingRemoteClientError)
require.True(t, ok, "should return error client when proxy is invalid and fallback disabled")
_, err := client.FetchPricingJSON(context.Background(), "http://example.com")
require.Error(t, err)
require.Contains(t, err.Error(), "proxy client init failed")
}
func TestNewPricingRemoteClient_InvalidProxy_WithFallback(t *testing.T) {
client := NewPricingRemoteClient("://bad", true)
_, ok := client.(*pricingRemoteClient)
require.True(t, ok, "should fallback to direct client when allowed")
}
func TestPricingServiceSuite(t *testing.T) {
suite.Run(t, new(PricingServiceSuite))
}

View File

@@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
q = q.Where(promocode.CodeContainsFold(search))
}
total, err := q.Count(ctx)
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
@@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo
q := r.client.PromoCodeUsage.Query().
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
total, err := q.Count(ctx)
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}

View File

@@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecure := false
allowPrivate := false
validateResolvedIP := true
maxResponseBytes := defaultProxyProbeResponseMaxBytes
if cfg != nil {
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 {
maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes
}
}
if insecure {
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
@@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecureSkipVerify: insecure,
allowPrivateHosts: allowPrivate,
validateResolvedIP: validateResolvedIP,
maxResponseBytes: maxResponseBytes,
}
}
const (
defaultProxyProbeTimeout = 30 * time.Second
defaultProxyProbeTimeout = 30 * time.Second
defaultProxyProbeResponseMaxBytes = int64(1024 * 1024)
)
// probeURLs 按优先级排列的探测 URL 列表
@@ -52,6 +58,7 @@ type proxyProbeService struct {
insecureSkipVerify bool
allowPrivateHosts bool
validateResolvedIP bool
maxResponseBytes int64
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
@@ -59,7 +66,6 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
ProxyURL: proxyURL,
Timeout: defaultProxyProbeTimeout,
InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: s.validateResolvedIP,
AllowPrivateHosts: s.allowPrivateHosts,
})
@@ -98,10 +104,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
maxResponseBytes := s.maxResponseBytes
if maxResponseBytes <= 0 {
maxResponseBytes = defaultProxyProbeResponseMaxBytes
}
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
if err != nil {
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
}
if int64(len(body)) > maxResponseBytes {
return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes)
}
switch parser {
case "ip-api":

View File

@@ -6,6 +6,8 @@ import (
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/imroc/req/v3"
)
@@ -33,11 +35,11 @@ var sharedReqClients sync.Map
// getSharedReqClient 获取共享的 req 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建
func getSharedReqClient(opts reqClientOptions) *req.Client {
func getSharedReqClient(opts reqClientOptions) (*req.Client, error) {
key := buildReqClientKey(opts)
if cached, ok := sharedReqClients.Load(key); ok {
if c, ok := cached.(*req.Client); ok {
return c
return c, nil
}
}
@@ -48,15 +50,19 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
if opts.Impersonate {
client = client.ImpersonateChrome()
}
if strings.TrimSpace(opts.ProxyURL) != "" {
client.SetProxyURL(strings.TrimSpace(opts.ProxyURL))
trimmed, _, err := proxyurl.Parse(opts.ProxyURL)
if err != nil {
return nil, err
}
if trimmed != "" {
client.SetProxyURL(trimmed)
}
actual, _ := sharedReqClients.LoadOrStore(key, client)
if c, ok := actual.(*req.Client); ok {
return c
return c, nil
}
return client
return client, nil
}
func buildReqClientKey(opts reqClientOptions) string {

View File

@@ -26,11 +26,13 @@ func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
ProxyURL: "http://proxy.local:8080",
Timeout: time.Second,
}
clientDefault := getSharedReqClient(base)
clientDefault, err := getSharedReqClient(base)
require.NoError(t, err)
force := base
force.ForceHTTP2 = true
clientForce := getSharedReqClient(force)
clientForce, err := getSharedReqClient(force)
require.NoError(t, err)
require.NotSame(t, clientDefault, clientForce)
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
@@ -42,8 +44,10 @@ func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
ProxyURL: "http://proxy.local:8080",
Timeout: 2 * time.Second,
}
first := getSharedReqClient(opts)
second := getSharedReqClient(opts)
first, err := getSharedReqClient(opts)
require.NoError(t, err)
second, err := getSharedReqClient(opts)
require.NoError(t, err)
require.Same(t, first, second)
}
@@ -56,7 +60,8 @@ func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
key := buildReqClientKey(opts)
sharedReqClients.Store(key, "invalid")
client := getSharedReqClient(opts)
client, err := getSharedReqClient(opts)
require.NoError(t, err)
require.NotNil(t, client)
loaded, ok := sharedReqClients.Load(key)
@@ -71,20 +76,45 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
Timeout: 4 * time.Second,
Impersonate: true,
}
client := getSharedReqClient(opts)
client, err := getSharedReqClient(opts)
require.NoError(t, err)
require.NotNil(t, client)
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}
func TestGetSharedReqClient_InvalidProxyURL(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "://missing-scheme",
Timeout: time.Second,
}
_, err := getSharedReqClient(opts)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid proxy URL")
}
func TestGetSharedReqClient_ProxyURLMissingHost(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "http://",
Timeout: time.Second,
}
_, err := getSharedReqClient(opts)
require.Error(t, err)
require.Contains(t, err.Error(), "proxy URL missing host")
}
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("http://proxy.local:8080")
client, err := createOpenAIReqClient("http://proxy.local:8080")
require.NoError(t, err)
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createGeminiReqClient("http://proxy.local:8080")
client, err := createGeminiReqClient("http://proxy.local:8080")
require.NoError(t, err)
require.Equal(t, "", forceHTTPVersion(t, client))
}

View File

@@ -0,0 +1,141 @@
package repository
import (
"context"
"errors"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// RPM 计数器缓存常量定义
//
// 设计说明:
// 使用 Redis 简单计数器跟踪每个账号每分钟的请求数:
// - Key: rpm:{accountID}:{minuteTimestamp}
// - Value: 当前分钟内的请求计数
// - TTL: 120 秒(覆盖当前分钟 + 一定冗余)
//
// 使用 TxPipelineMULTI/EXEC执行 INCR + EXPIRE保证原子性且兼容 Redis Cluster。
// 通过 rdb.Time() 获取服务端时间,避免多实例时钟不同步。
//
// 设计决策:
// - TxPipeline vs PipelinePipeline 仅合并发送但不保证原子TxPipeline 使用 MULTI/EXEC 事务保证原子执行。
// - rdb.Time() 单独调用Pipeline/TxPipeline 中无法引用前一命令的结果,因此 TIME 必须单独调用2 RTT
// Lua 脚本可以做到 1 RTT但在 Redis Cluster 中动态拼接 key 存在 CROSSSLOT 风险,选择安全性优先。
const (
// RPM 计数器键前缀
// 格式: rpm:{accountID}:{minuteTimestamp}
rpmKeyPrefix = "rpm:"
// RPM 计数器 TTL120 秒,覆盖当前分钟窗口 + 冗余)
rpmKeyTTL = 120 * time.Second
)
// RPMCacheImpl RPM 计数器缓存 Redis 实现
type RPMCacheImpl struct {
rdb *redis.Client
}
// NewRPMCache 创建 RPM 计数器缓存
func NewRPMCache(rdb *redis.Client) service.RPMCache {
return &RPMCacheImpl{rdb: rdb}
}
// currentMinuteKey 获取当前分钟的完整 Redis key
// 使用 rdb.Time() 获取 Redis 服务端时间,避免多实例时钟偏差
func (c *RPMCacheImpl) currentMinuteKey(ctx context.Context, accountID int64) (string, error) {
serverTime, err := c.rdb.Time(ctx).Result()
if err != nil {
return "", fmt.Errorf("redis TIME: %w", err)
}
minuteTS := serverTime.Unix() / 60
return fmt.Sprintf("%s%d:%d", rpmKeyPrefix, accountID, minuteTS), nil
}
// currentMinuteSuffix 获取当前分钟时间戳后缀(供批量操作使用)
// 使用 rdb.Time() 获取 Redis 服务端时间
func (c *RPMCacheImpl) currentMinuteSuffix(ctx context.Context) (string, error) {
serverTime, err := c.rdb.Time(ctx).Result()
if err != nil {
return "", fmt.Errorf("redis TIME: %w", err)
}
minuteTS := serverTime.Unix() / 60
return strconv.FormatInt(minuteTS, 10), nil
}
// IncrementRPM 原子递增并返回当前分钟的计数
// 使用 TxPipeline (MULTI/EXEC) 执行 INCR + EXPIRE保证原子性且兼容 Redis Cluster
func (c *RPMCacheImpl) IncrementRPM(ctx context.Context, accountID int64) (int, error) {
key, err := c.currentMinuteKey(ctx, accountID)
if err != nil {
return 0, fmt.Errorf("rpm increment: %w", err)
}
// 使用 TxPipeline (MULTI/EXEC) 保证 INCR + EXPIRE 原子执行
// EXPIRE 幂等,每次都设置不影响正确性
pipe := c.rdb.TxPipeline()
incrCmd := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, rpmKeyTTL)
if _, err := pipe.Exec(ctx); err != nil {
return 0, fmt.Errorf("rpm increment: %w", err)
}
return int(incrCmd.Val()), nil
}
// GetRPM 获取当前分钟的 RPM 计数
func (c *RPMCacheImpl) GetRPM(ctx context.Context, accountID int64) (int, error) {
key, err := c.currentMinuteKey(ctx, accountID)
if err != nil {
return 0, fmt.Errorf("rpm get: %w", err)
}
val, err := c.rdb.Get(ctx, key).Int()
if errors.Is(err, redis.Nil) {
return 0, nil // 当前分钟无记录
}
if err != nil {
return 0, fmt.Errorf("rpm get: %w", err)
}
return val, nil
}
// GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline
func (c *RPMCacheImpl) GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return map[int64]int{}, nil
}
// 获取当前分钟后缀
minuteSuffix, err := c.currentMinuteSuffix(ctx)
if err != nil {
return nil, fmt.Errorf("rpm batch get: %w", err)
}
// 使用 Pipeline 批量 GET
pipe := c.rdb.Pipeline()
cmds := make(map[int64]*redis.StringCmd, len(accountIDs))
for _, id := range accountIDs {
key := fmt.Sprintf("%s%d:%s", rpmKeyPrefix, id, minuteSuffix)
cmds[id] = pipe.Get(ctx, key)
}
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
return nil, fmt.Errorf("rpm batch get: %w", err)
}
result := make(map[int64]int, len(accountIDs))
for id, cmd := range cmds {
if val, err := cmd.Int(); err == nil {
result[id] = val
} else {
result[id] = 0
}
}
return result, nil
}

View File

@@ -0,0 +1,177 @@
package repository
import (
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/internal/config"
)
const (
securitySecretKeyJWT = "jwt_secret"
securitySecretReadRetryMax = 5
securitySecretReadRetryWait = 10 * time.Millisecond
)
var readRandomBytes = rand.Read
func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config.Config) error {
if client == nil {
return fmt.Errorf("nil ent client")
}
if cfg == nil {
return fmt.Errorf("nil config")
}
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
if cfg.JWT.Secret != "" {
storedSecret, err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret)
if err != nil {
return fmt.Errorf("persist jwt secret: %w", err)
}
if storedSecret != cfg.JWT.Secret {
log.Println("Warning: configured JWT secret mismatches persisted value; using persisted secret for cross-instance consistency.")
}
cfg.JWT.Secret = storedSecret
return nil
}
secret, created, err := getOrCreateGeneratedSecuritySecret(ctx, client, securitySecretKeyJWT, 32)
if err != nil {
return fmt.Errorf("ensure jwt secret: %w", err)
}
cfg.JWT.Secret = secret
if created {
log.Println("Warning: JWT secret auto-generated and persisted to database. Consider rotating to a managed secret for production.")
}
return nil
}
func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client, key string, byteLength int) (string, bool, error) {
existing, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
if err == nil {
value := strings.TrimSpace(existing.Value)
if len([]byte(value)) < 32 {
return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key)
}
return value, false, nil
}
if !ent.IsNotFound(err) {
return "", false, err
}
generated, err := generateHexSecret(byteLength)
if err != nil {
return "", false, err
}
if err := client.SecuritySecret.Create().
SetKey(key).
SetValue(generated).
OnConflictColumns(securitysecret.FieldKey).
DoNothing().
Exec(ctx); err != nil {
if !isSQLNoRowsError(err) {
return "", false, err
}
}
stored, err := querySecuritySecretWithRetry(ctx, client, key)
if err != nil {
return "", false, err
}
value := strings.TrimSpace(stored.Value)
if len([]byte(value)) < 32 {
return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key)
}
return value, value == generated, nil
}
func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) (string, error) {
value = strings.TrimSpace(value)
if len([]byte(value)) < 32 {
return "", fmt.Errorf("secret %q must be at least 32 bytes", key)
}
if err := client.SecuritySecret.Create().
SetKey(key).
SetValue(value).
OnConflictColumns(securitysecret.FieldKey).
DoNothing().
Exec(ctx); err != nil {
if !isSQLNoRowsError(err) {
return "", err
}
}
stored, err := querySecuritySecretWithRetry(ctx, client, key)
if err != nil {
return "", err
}
storedValue := strings.TrimSpace(stored.Value)
if len([]byte(storedValue)) < 32 {
return "", fmt.Errorf("stored secret %q must be at least 32 bytes", key)
}
return storedValue, nil
}
func querySecuritySecretWithRetry(ctx context.Context, client *ent.Client, key string) (*ent.SecuritySecret, error) {
var lastErr error
for attempt := 0; attempt <= securitySecretReadRetryMax; attempt++ {
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
if err == nil {
return stored, nil
}
if !isSecretNotFoundError(err) {
return nil, err
}
lastErr = err
if attempt == securitySecretReadRetryMax {
break
}
timer := time.NewTimer(securitySecretReadRetryWait)
select {
case <-ctx.Done():
timer.Stop()
return nil, ctx.Err()
case <-timer.C:
}
}
return nil, lastErr
}
func isSecretNotFoundError(err error) bool {
if err == nil {
return false
}
return ent.IsNotFound(err) || isSQLNoRowsError(err)
}
func isSQLNoRowsError(err error) bool {
if err == nil {
return false
}
return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set")
}
func generateHexSecret(byteLength int) (string, error) {
if byteLength <= 0 {
byteLength = 32
}
buf := make([]byte, byteLength)
if _, err := readRandomBytes(buf); err != nil {
return "", fmt.Errorf("generate random secret: %w", err)
}
return hex.EncodeToString(buf), nil
}

View File

@@ -0,0 +1,337 @@
package repository
import (
"context"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"strings"
"sync"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func newSecuritySecretTestClient(t *testing.T) *dbent.Client {
t.Helper()
name := strings.ReplaceAll(t.Name(), "/", "_")
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", name)
db, err := sql.Open("sqlite", dsn)
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return client
}
func TestEnsureBootstrapSecretsNilInputs(t *testing.T) {
err := ensureBootstrapSecrets(context.Background(), nil, &config.Config{})
require.Error(t, err)
require.Contains(t, err.Error(), "nil ent client")
client := newSecuritySecretTestClient(t)
err = ensureBootstrapSecrets(context.Background(), client, nil)
require.Error(t, err)
require.Contains(t, err.Error(), "nil config")
}
func TestEnsureBootstrapSecretsGenerateAndPersistJWTSecret(t *testing.T) {
client := newSecuritySecretTestClient(t)
cfg := &config.Config{}
err := ensureBootstrapSecrets(context.Background(), client, cfg)
require.NoError(t, err)
require.NotEmpty(t, cfg.JWT.Secret)
require.GreaterOrEqual(t, len([]byte(cfg.JWT.Secret)), 32)
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
require.NoError(t, err)
require.Equal(t, cfg.JWT.Secret, stored.Value)
}
func TestEnsureBootstrapSecretsLoadExistingJWTSecret(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("existing-jwt-secret-32bytes-long!!!!").Save(context.Background())
require.NoError(t, err)
cfg := &config.Config{}
err = ensureBootstrapSecrets(context.Background(), client, cfg)
require.NoError(t, err)
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret)
}
func TestEnsureBootstrapSecretsRejectInvalidStoredSecret(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("too-short").Save(context.Background())
require.NoError(t, err)
cfg := &config.Config{}
err = ensureBootstrapSecrets(context.Background(), client, cfg)
require.Error(t, err)
require.Contains(t, err.Error(), "at least 32 bytes")
}
func TestEnsureBootstrapSecretsPersistConfiguredJWTSecret(t *testing.T) {
client := newSecuritySecretTestClient(t)
cfg := &config.Config{
JWT: config.JWTConfig{Secret: "configured-jwt-secret-32bytes-long!!"},
}
err := ensureBootstrapSecrets(context.Background(), client, cfg)
require.NoError(t, err)
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
require.NoError(t, err)
require.Equal(t, "configured-jwt-secret-32bytes-long!!", stored.Value)
}
func TestEnsureBootstrapSecretsConfiguredSecretTooShort(t *testing.T) {
client := newSecuritySecretTestClient(t)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "short"}}
err := ensureBootstrapSecrets(context.Background(), client, cfg)
require.Error(t, err)
require.Contains(t, err.Error(), "at least 32 bytes")
}
func TestEnsureBootstrapSecretsConfiguredSecretDuplicateIgnored(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := client.SecuritySecret.Create().
SetKey(securitySecretKeyJWT).
SetValue("existing-jwt-secret-32bytes-long!!!!").
Save(context.Background())
require.NoError(t, err)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "another-configured-jwt-secret-32!!!!"}}
err = ensureBootstrapSecrets(context.Background(), client, cfg)
require.NoError(t, err)
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
require.NoError(t, err)
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", stored.Value)
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret)
}
func TestGetOrCreateGeneratedSecuritySecretTrimmedExistingValue(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := client.SecuritySecret.Create().
SetKey("trimmed_key").
SetValue(" existing-trimmed-secret-32bytes-long!! ").
Save(context.Background())
require.NoError(t, err)
value, created, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "trimmed_key", 32)
require.NoError(t, err)
require.False(t, created)
require.Equal(t, "existing-trimmed-secret-32bytes-long!!", value)
}
func TestGetOrCreateGeneratedSecuritySecretQueryError(t *testing.T) {
client := newSecuritySecretTestClient(t)
require.NoError(t, client.Close())
_, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "closed_client_key", 32)
require.Error(t, err)
}
func TestGetOrCreateGeneratedSecuritySecretCreateValidationError(t *testing.T) {
client := newSecuritySecretTestClient(t)
tooLongKey := strings.Repeat("k", 101)
_, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, tooLongKey, 32)
require.Error(t, err)
}
func TestGetOrCreateGeneratedSecuritySecretConcurrentCreation(t *testing.T) {
client := newSecuritySecretTestClient(t)
const goroutines = 8
key := "concurrent_bootstrap_key"
values := make([]string, goroutines)
createdFlags := make([]bool, goroutines)
errs := make([]error, goroutines)
var wg sync.WaitGroup
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
values[idx], createdFlags[idx], errs[idx] = getOrCreateGeneratedSecuritySecret(context.Background(), client, key, 32)
}(i)
}
wg.Wait()
for i := range errs {
require.NoError(t, errs[i])
require.NotEmpty(t, values[i])
}
for i := 1; i < len(values); i++ {
require.Equal(t, values[0], values[i])
}
createdCount := 0
for _, created := range createdFlags {
if created {
createdCount++
}
}
require.GreaterOrEqual(t, createdCount, 1)
require.LessOrEqual(t, createdCount, 1)
count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Count(context.Background())
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestGetOrCreateGeneratedSecuritySecretGenerateError(t *testing.T) {
client := newSecuritySecretTestClient(t)
originalRead := readRandomBytes
readRandomBytes = func([]byte) (int, error) {
return 0, errors.New("boom")
}
t.Cleanup(func() {
readRandomBytes = originalRead
})
_, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "gen_error_key", 32)
require.Error(t, err)
require.Contains(t, err.Error(), "boom")
}
func TestCreateSecuritySecretIfAbsent(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "short")
require.Error(t, err)
require.Contains(t, err.Error(), "at least 32 bytes")
stored, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "valid-jwt-secret-value-32bytes-long")
require.NoError(t, err)
require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored)
stored, err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "another-valid-secret-value-32bytes")
require.NoError(t, err)
require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored)
count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ("abc")).Count(context.Background())
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := createSecuritySecretIfAbsent(
context.Background(),
client,
strings.Repeat("k", 101),
"valid-jwt-secret-value-32bytes-long",
)
require.Error(t, err)
}
func TestCreateSecuritySecretIfAbsentExecError(t *testing.T) {
client := newSecuritySecretTestClient(t)
require.NoError(t, client.Close())
_, err := createSecuritySecretIfAbsent(context.Background(), client, "closed-client-key", "valid-jwt-secret-value-32bytes-long")
require.Error(t, err)
}
func TestQuerySecuritySecretWithRetrySuccess(t *testing.T) {
client := newSecuritySecretTestClient(t)
created, err := client.SecuritySecret.Create().
SetKey("retry_success_key").
SetValue("retry-success-jwt-secret-value-32!!").
Save(context.Background())
require.NoError(t, err)
got, err := querySecuritySecretWithRetry(context.Background(), client, "retry_success_key")
require.NoError(t, err)
require.Equal(t, created.ID, got.ID)
require.Equal(t, "retry-success-jwt-secret-value-32!!", got.Value)
}
func TestQuerySecuritySecretWithRetryExhausted(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := querySecuritySecretWithRetry(context.Background(), client, "retry_missing_key")
require.Error(t, err)
require.True(t, isSecretNotFoundError(err))
}
func TestQuerySecuritySecretWithRetryContextCanceled(t *testing.T) {
client := newSecuritySecretTestClient(t)
ctx, cancel := context.WithTimeout(context.Background(), securitySecretReadRetryWait/2)
defer cancel()
_, err := querySecuritySecretWithRetry(ctx, client, "retry_ctx_cancel_key")
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
}
func TestQuerySecuritySecretWithRetryNonNotFoundError(t *testing.T) {
client := newSecuritySecretTestClient(t)
require.NoError(t, client.Close())
_, err := querySecuritySecretWithRetry(context.Background(), client, "retry_closed_client_key")
require.Error(t, err)
require.False(t, isSecretNotFoundError(err))
}
func TestSecretNotFoundHelpers(t *testing.T) {
require.False(t, isSecretNotFoundError(nil))
require.False(t, isSQLNoRowsError(nil))
require.True(t, isSQLNoRowsError(sql.ErrNoRows))
require.True(t, isSQLNoRowsError(fmt.Errorf("wrapped: %w", sql.ErrNoRows)))
require.True(t, isSQLNoRowsError(errors.New("sql: no rows in result set")))
require.True(t, isSecretNotFoundError(sql.ErrNoRows))
require.True(t, isSecretNotFoundError(errors.New("sql: no rows in result set")))
require.False(t, isSecretNotFoundError(errors.New("some other error")))
}
func TestGenerateHexSecretReadError(t *testing.T) {
originalRead := readRandomBytes
readRandomBytes = func([]byte) (int, error) {
return 0, errors.New("read random failed")
}
t.Cleanup(func() {
readRandomBytes = originalRead
})
_, err := generateHexSecret(32)
require.Error(t, err)
require.Contains(t, err.Error(), "read random failed")
}
func TestGenerateHexSecretLengths(t *testing.T) {
v1, err := generateHexSecret(0)
require.NoError(t, err)
require.Len(t, v1, 64)
_, err = hex.DecodeString(v1)
require.NoError(t, err)
v2, err := generateHexSecret(16)
require.NoError(t, err)
require.Len(t, v2, 32)
_, err = hex.DecodeString(v2)
require.NoError(t, err)
require.NotEqual(t, v1, v2)
}

View File

@@ -122,7 +122,7 @@ func (s *SettingRepoSuite) TestSet_EmptyValue() {
func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
// 模拟保存站点设置,部分字段有值,部分字段为空
settings := map[string]string{
"site_name": "AICodex2API",
"site_name": "Sub2api",
"site_subtitle": "Subscription to API",
"site_logo": "", // 用户未上传Logo
"api_base_url": "", // 用户未设置API地址
@@ -136,7 +136,7 @@ func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
result, err := s.repo.GetMultiple(s.ctx, []string{"site_name", "site_subtitle", "site_logo", "api_base_url", "contact_info", "doc_url"})
s.Require().NoError(err, "GetMultiple after SetMultiple with empty values")
s.Require().Equal("AICodex2API", result["site_name"])
s.Require().Equal("Sub2api", result["site_name"])
s.Require().Equal("Subscription to API", result["site_subtitle"])
s.Require().Equal("", result["site_logo"], "empty site_logo should be preserved")
s.Require().Equal("", result["api_base_url"], "empty api_base_url should be preserved")

View File

@@ -41,7 +41,7 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
repo := NewAPIKeyRepository(client)
repo := NewAPIKeyRepository(client, integrationDB)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
@@ -73,7 +73,7 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
repo := NewAPIKeyRepository(client)
repo := NewAPIKeyRepository(client, integrationDB)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
@@ -93,7 +93,7 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
repo := NewAPIKeyRepository(client)
repo := NewAPIKeyRepository(client, integrationDB)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),

View File

@@ -0,0 +1,98 @@
package repository
import (
"context"
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraAccountRepository 实现 service.SoraAccountRepository 接口。
// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
//
// 设计说明:
// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
// - 与 accounts 主表通过外键关联ON DELETE CASCADE 确保级联删除
type soraAccountRepository struct {
sql *sql.DB
}
// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository {
return &soraAccountRepository{sql: sqlDB}
}
// Upsert 创建或更新 Sora 账号扩展信息
// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
accessToken, accessOK := updates["access_token"].(string)
refreshToken, refreshOK := updates["refresh_token"].(string)
sessionToken, sessionOK := updates["session_token"].(string)
if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" {
if !sessionOK {
return errors.New("缺少 access_token/refresh_token且未提供可更新字段")
}
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_accounts
SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
updated_at = NOW()
WHERE account_id = $1
`, accountID, sessionToken)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return errors.New("sora_accounts 记录不存在,无法仅更新 session_token")
}
return nil
}
_, err := r.sql.ExecContext(ctx, `
INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (account_id) DO UPDATE SET
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
updated_at = NOW()
`, accountID, accessToken, refreshToken, sessionToken)
return err
}
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
FROM sora_accounts
WHERE account_id = $1
`, accountID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return nil, nil // 记录不存在
}
var sa service.SoraAccount
if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil {
return nil, err
}
return &sa, nil
}
// Delete 删除 Sora 账号扩展信息
func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error {
_, err := r.sql.ExecContext(ctx, `
DELETE FROM sora_accounts WHERE account_id = $1
`, accountID)
return err
}

View File

@@ -0,0 +1,419 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
// 使用原生 SQL 操作 sora_generations 表。
type soraGenerationRepository struct {
sql *sql.DB
}
// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
return &soraGenerationRepository{sql: sqlDB}
}
func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
err := r.sql.QueryRowContext(ctx, `
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`,
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
).Scan(&gen.ID, &gen.CreatedAt)
return err
}
// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
func (r *soraGenerationRepository) CreatePendingWithLimit(
ctx context.Context,
gen *service.SoraGeneration,
activeStatuses []string,
maxActive int64,
) error {
if gen == nil {
return fmt.Errorf("generation is nil")
}
if maxActive <= 0 {
return r.Create(ctx, gen)
}
if len(activeStatuses) == 0 {
activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
}
tx, err := r.sql.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
// 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
return err
}
placeholders := make([]string, len(activeStatuses))
args := make([]any, 0, 1+len(activeStatuses))
args = append(args, gen.UserID)
for i, s := range activeStatuses {
placeholders[i] = fmt.Sprintf("$%d", i+2)
args = append(args, s)
}
countQuery := fmt.Sprintf(
`SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
strings.Join(placeholders, ","),
)
var activeCount int64
if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
return err
}
if activeCount >= maxActive {
return service.ErrSoraGenerationConcurrencyLimit
}
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
if err := tx.QueryRowContext(ctx, `
INSERT INTO sora_generations (
user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, created_at
`,
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
).Scan(&gen.ID, &gen.CreatedAt); err != nil {
return err
}
return tx.Commit()
}
func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
gen := &service.SoraGeneration{}
var mediaURLsJSON, s3KeysJSON []byte
var completedAt sql.NullTime
var apiKeyID sql.NullInt64
err := r.sql.QueryRowContext(ctx, `
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations WHERE id = $1
`, id).Scan(
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
&gen.CreatedAt, &completedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("生成记录不存在")
}
return nil, err
}
if apiKeyID.Valid {
gen.APIKeyID = &apiKeyID.Int64
}
if completedAt.Valid {
gen.CompletedAt = &completedAt.Time
}
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
return gen, nil
}
func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
var completedAt *time.Time
if gen.CompletedAt != nil {
completedAt = gen.CompletedAt
}
_, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations SET
status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
error_message = $9, completed_at = $10
WHERE id = $1
`,
gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
gen.ErrorMessage, completedAt,
)
return err
}
// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2, upstream_task_id = $3
WHERE id = $1 AND status = $4
`,
id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
func (r *soraGenerationRepository) UpdateCompletedIfActive(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
completedAt time.Time,
) (bool, error) {
mediaURLsJSON, _ := json.Marshal(mediaURLs)
s3KeysJSON, _ := json.Marshal(s3Keys)
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2,
media_url = $3,
media_urls = $4,
file_size_bytes = $5,
storage_type = $6,
s3_object_keys = $7,
error_message = '',
completed_at = $8
WHERE id = $1 AND status IN ($9, $10)
`,
id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
func (r *soraGenerationRepository) UpdateFailedIfActive(
ctx context.Context,
id int64,
errMsg string,
completedAt time.Time,
) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2,
error_message = $3,
completed_at = $4
WHERE id = $1 AND status IN ($5, $6)
`,
id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET status = $2, completed_at = $3
WHERE id = $1 AND status IN ($4, $5)
`,
id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at
func (r *soraGenerationRepository) UpdateStorageIfCompleted(
ctx context.Context,
id int64,
mediaURL string,
mediaURLs []string,
storageType string,
s3Keys []string,
fileSizeBytes int64,
) (bool, error) {
mediaURLsJSON, _ := json.Marshal(mediaURLs)
s3KeysJSON, _ := json.Marshal(s3Keys)
result, err := r.sql.ExecContext(ctx, `
UPDATE sora_generations
SET media_url = $2,
media_urls = $3,
file_size_bytes = $4,
storage_type = $5,
s3_object_keys = $6
WHERE id = $1 AND status = $7
`,
id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
return err
}
func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
// 构建 WHERE 条件
conditions := []string{"user_id = $1"}
args := []any{params.UserID}
argIdx := 2
if params.Status != "" {
// 支持逗号分隔的多状态
statuses := strings.Split(params.Status, ",")
placeholders := make([]string, len(statuses))
for i, s := range statuses {
placeholders[i] = fmt.Sprintf("$%d", argIdx)
args = append(args, strings.TrimSpace(s))
argIdx++
}
conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
}
if params.StorageType != "" {
storageTypes := strings.Split(params.StorageType, ",")
placeholders := make([]string, len(storageTypes))
for i, s := range storageTypes {
placeholders[i] = fmt.Sprintf("$%d", argIdx)
args = append(args, strings.TrimSpace(s))
argIdx++
}
conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
}
if params.MediaType != "" {
conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
args = append(args, params.MediaType)
argIdx++
}
whereClause := "WHERE " + strings.Join(conditions, " AND ")
// 计数
var total int64
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
return nil, 0, err
}
// 分页查询
offset := (params.Page - 1) * params.PageSize
listQuery := fmt.Sprintf(`
SELECT id, user_id, api_key_id, model, prompt, media_type,
status, media_url, media_urls, file_size_bytes,
storage_type, s3_object_keys, upstream_task_id, error_message,
created_at, completed_at
FROM sora_generations %s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, argIdx, argIdx+1)
args = append(args, params.PageSize, offset)
rows, err := r.sql.QueryContext(ctx, listQuery, args...)
if err != nil {
return nil, 0, err
}
defer func() {
_ = rows.Close()
}()
var results []*service.SoraGeneration
for rows.Next() {
gen := &service.SoraGeneration{}
var mediaURLsJSON, s3KeysJSON []byte
var completedAt sql.NullTime
var apiKeyID sql.NullInt64
if err := rows.Scan(
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
&gen.CreatedAt, &completedAt,
); err != nil {
return nil, 0, err
}
if apiKeyID.Valid {
gen.APIKeyID = &apiKeyID.Int64
}
if completedAt.Valid {
gen.CompletedAt = &completedAt.Time
}
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
results = append(results, gen)
}
return results, total, rows.Err()
}
func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
if len(statuses) == 0 {
return 0, nil
}
placeholders := make([]string, len(statuses))
args := []any{userID}
for i, s := range statuses {
placeholders[i] = fmt.Sprintf("$%d", i+2)
args = append(args, s)
}
var count int64
query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
return count, err
}

View File

@@ -362,7 +362,12 @@ func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any)
idx++
}
}
if filters.Stream != nil {
if filters.RequestType != nil {
condition, conditionArgs := buildRequestTypeFilterCondition(idx, *filters.RequestType)
conditions = append(conditions, condition)
args = append(args, conditionArgs...)
idx += len(conditionArgs)
} else if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", idx))
args = append(args, *filters.Stream)
idx++

View File

@@ -466,6 +466,38 @@ func TestBuildUsageCleanupWhere(t *testing.T) {
require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args)
}
func TestBuildUsageCleanupWhereRequestTypePriority(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
requestType := int16(service.RequestTypeWSV2)
stream := false
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
StartTime: start,
EndTime: end,
RequestType: &requestType,
Stream: &stream,
})
require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))", where)
require.Equal(t, []any{start, end, requestType}, args)
}
func TestBuildUsageCleanupWhereRequestTypeLegacyFallback(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
requestType := int16(service.RequestTypeStream)
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
StartTime: start,
EndTime: end,
RequestType: &requestType,
})
require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))", where)
require.Equal(t, []any{start, end, requestType}, args)
}
func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)

File diff suppressed because it is too large Load Diff

View File

@@ -130,6 +130,62 @@ func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
}
func (s *UsageLogRepoSuite) TestGetByID_ReturnsOpenAIWSMode() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-ws@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-ws", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-ws"})
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "gpt-5.3-codex",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 1.0,
OpenAIWSMode: true,
CreatedAt: timezone.Today().Add(3 * time.Hour),
}
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err)
got, err := s.repo.GetByID(s.ctx, log.ID)
s.Require().NoError(err)
s.Require().True(got.OpenAIWSMode)
}
func (s *UsageLogRepoSuite) TestGetByID_ReturnsRequestTypeAndLegacyFallback() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-request-type@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-request-type", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-request-type"})
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "gpt-5.3-codex",
RequestType: service.RequestTypeWSV2,
Stream: true,
OpenAIWSMode: false,
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 1.0,
CreatedAt: timezone.Today().Add(4 * time.Hour),
}
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err)
got, err := s.repo.GetByID(s.ctx, log.ID)
s.Require().NoError(err)
s.Require().Equal(service.RequestTypeWSV2, got.RequestType)
s.Require().True(got.Stream)
s.Require().True(got.OpenAIWSMode)
}
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
@@ -648,7 +704,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID})
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}, time.Time{}, time.Time{})
s.Require().NoError(err, "GetBatchUserUsageStats")
s.Require().Len(stats, 2)
s.Require().NotNil(stats[user1.ID])
@@ -656,7 +712,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
}
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{})
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
@@ -672,13 +728,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}, time.Time{}, time.Time{})
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
s.Require().Len(stats, 2)
}
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
@@ -944,17 +1000,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime := base.Add(48 * time.Hour)
// Test with user filter
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2)
// Test with apiKey filter
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil)
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2)
// Test with both filters
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil)
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2)
}
@@ -971,7 +1027,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2)
}
@@ -1017,17 +1073,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime := base.Add(2 * time.Hour)
// Test with user filter
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil)
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2)
// Test with apiKey filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil)
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2)
// Test with account filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil)
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2)
}

View File

@@ -0,0 +1,328 @@
package repository
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
log := &service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-1",
Model: "gpt-5",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1,
ActualCost: 1,
BillingType: service.BillingTypeBalance,
RequestType: service.RequestTypeWSV2,
Stream: false,
OpenAIWSMode: false,
CreatedAt: createdAt,
}
mock.ExpectQuery("INSERT INTO usage_logs").
WithArgs(
log.UserID,
log.APIKeyID,
log.AccountID,
log.RequestID,
log.Model,
sqlmock.AnyArg(), // group_id
sqlmock.AnyArg(), // subscription_id
log.InputTokens,
log.OutputTokens,
log.CacheCreationTokens,
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
log.CacheReadCost,
log.TotalCost,
log.ActualCost,
log.RateMultiplier,
log.AccountRateMultiplier,
log.BillingType,
int16(service.RequestTypeWSV2),
true,
true,
sqlmock.AnyArg(), // duration_ms
sqlmock.AnyArg(), // first_token_ms
sqlmock.AnyArg(), // user_agent
sqlmock.AnyArg(), // ip_address
log.ImageCount,
sqlmock.AnyArg(), // image_size
sqlmock.AnyArg(), // media_type
sqlmock.AnyArg(), // reasoning_effort
log.CacheTTLOverridden,
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
inserted, err := repo.Create(context.Background(), log)
require.NoError(t, err)
require.True(t, inserted)
require.Equal(t, int64(99), log.ID)
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
require.True(t, log.Stream)
require.True(t, log.OpenAIWSMode)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
requestType := int16(service.RequestTypeWSV2)
stream := false
filters := usagestats.UsageLogFilters{
RequestType: &requestType,
Stream: &stream,
ExactTotal: true,
}
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
WithArgs(requestType).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
mock.ExpectQuery("SELECT .* FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\) ORDER BY id DESC LIMIT \\$2 OFFSET \\$3").
WithArgs(requestType, 20, 0).
WillReturnRows(sqlmock.NewRows([]string{"id"}))
logs, page, err := repo.ListWithFilters(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}, filters)
require.NoError(t, err)
require.Empty(t, logs)
require.NotNil(t, page)
require.Equal(t, int64(0), page.Total)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
requestType := int16(service.RequestTypeStream)
stream := true
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
WithArgs(start, end, requestType).
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"}))
trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
require.NoError(t, err)
require.Empty(t, trend)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
requestType := int16(service.RequestTypeWSV2)
stream := false
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
WithArgs(start, end, requestType).
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"}))
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
require.NoError(t, err)
require.Empty(t, stats)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
requestType := int16(service.RequestTypeSync)
stream := true
filters := usagestats.UsageLogFilters{
RequestType: &requestType,
Stream: &stream,
}
mock.ExpectQuery("FROM usage_logs\\s+WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE\\)\\)").
WithArgs(requestType).
WillReturnRows(sqlmock.NewRows([]string{
"total_requests",
"total_input_tokens",
"total_output_tokens",
"total_cache_tokens",
"total_cost",
"total_actual_cost",
"total_account_cost",
"avg_duration_ms",
}).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0))
stats, err := repo.GetStatsWithFilters(context.Background(), filters)
require.NoError(t, err)
require.Equal(t, int64(1), stats.TotalRequests)
require.Equal(t, int64(9), stats.TotalTokens)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
tests := []struct {
name string
request int16
wantWhere string
wantArg int16
}{
{
name: "sync_with_legacy_fallback",
request: int16(service.RequestTypeSync),
wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE))",
wantArg: int16(service.RequestTypeSync),
},
{
name: "stream_with_legacy_fallback",
request: int16(service.RequestTypeStream),
wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))",
wantArg: int16(service.RequestTypeStream),
},
{
name: "ws_v2_with_legacy_fallback",
request: int16(service.RequestTypeWSV2),
wantWhere: "(request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))",
wantArg: int16(service.RequestTypeWSV2),
},
{
name: "invalid_request_type_normalized_to_unknown",
request: int16(99),
wantWhere: "request_type = $3",
wantArg: int16(service.RequestTypeUnknown),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
where, args := buildRequestTypeFilterCondition(3, tt.request)
require.Equal(t, tt.wantWhere, where)
require.Equal(t, []any{tt.wantArg}, args)
})
}
}
type usageLogScannerStub struct {
values []any
}
func (s usageLogScannerStub) Scan(dest ...any) error {
if len(dest) != len(s.values) {
return fmt.Errorf("scan arg count mismatch: got %d want %d", len(dest), len(s.values))
}
for i := range dest {
dv := reflect.ValueOf(dest[i])
if dv.Kind() != reflect.Ptr {
return fmt.Errorf("dest[%d] is not pointer", i)
}
dv.Elem().Set(reflect.ValueOf(s.values[i]))
}
return nil
}
func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
t.Run("request_type_ws_v2_overrides_legacy", func(t *testing.T) {
now := time.Now().UTC()
log, err := scanUsageLog(usageLogScannerStub{values: []any{
int64(1), // id
int64(10), // user_id
int64(20), // api_key_id
int64(30), // account_id
sql.NullString{Valid: true, String: "req-1"},
"gpt-5", // model
sql.NullInt64{}, // group_id
sql.NullInt64{}, // subscription_id
1, // input_tokens
2, // output_tokens
3, // cache_creation_tokens
4, // cache_read_tokens
5, // cache_creation_5m_tokens
6, // cache_creation_1h_tokens
0.1, // input_cost
0.2, // output_cost
0.3, // cache_creation_cost
0.4, // cache_read_cost
1.0, // total_cost
0.9, // actual_cost
1.0, // rate_multiplier
sql.NullFloat64{}, // account_rate_multiplier
int16(service.BillingTypeBalance),
int16(service.RequestTypeWSV2),
false, // legacy stream
false, // legacy openai ws
sql.NullInt64{},
sql.NullInt64{},
sql.NullString{},
sql.NullString{},
0,
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
now,
}})
require.NoError(t, err)
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
require.True(t, log.Stream)
require.True(t, log.OpenAIWSMode)
})
t.Run("request_type_unknown_falls_back_to_legacy", func(t *testing.T) {
now := time.Now().UTC()
log, err := scanUsageLog(usageLogScannerStub{values: []any{
int64(2),
int64(11),
int64(21),
int64(31),
sql.NullString{Valid: true, String: "req-2"},
"gpt-5",
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0,
sql.NullFloat64{},
int16(service.BillingTypeBalance),
int16(service.RequestTypeUnknown),
true,
false,
sql.NullInt64{},
sql.NullInt64{},
sql.NullString{},
sql.NullString{},
0,
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
now,
}})
require.NoError(t, err)
require.Equal(t, service.RequestTypeStream, log.RequestType)
require.True(t, log.Stream)
require.False(t, log.OpenAIWSMode)
})
}

View File

@@ -0,0 +1,41 @@
//go:build unit
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestSafeDateFormat(t *testing.T) {
tests := []struct {
name string
granularity string
expected string
}{
// 合法值
{"hour", "hour", "YYYY-MM-DD HH24:00"},
{"day", "day", "YYYY-MM-DD"},
{"week", "week", "IYYY-IW"},
{"month", "month", "YYYY-MM"},
// 非法值回退到默认
{"空字符串", "", "YYYY-MM-DD"},
{"未知粒度 year", "year", "YYYY-MM-DD"},
{"未知粒度 minute", "minute", "YYYY-MM-DD"},
// 恶意字符串
{"SQL 注入尝试", "'; DROP TABLE users; --", "YYYY-MM-DD"},
{"带引号", "day'", "YYYY-MM-DD"},
{"带括号", "day)", "YYYY-MM-DD"},
{"Unicode", "日", "YYYY-MM-DD"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := safeDateFormat(tc.granularity)
require.Equal(t, tc.expected, got, "safeDateFormat(%q)", tc.granularity)
})
}
}

View File

@@ -6,6 +6,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type userGroupRateRepository struct {
@@ -41,6 +42,59 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
return result, nil
}
// GetByUserIDs 批量获取多个用户的专属分组倍率。
// 返回结构map[userID]map[groupID]rate
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
result := make(map[int64]map[int64]float64, len(userIDs))
if len(userIDs) == 0 {
return result, nil
}
uniqueIDs := make([]int64, 0, len(userIDs))
seen := make(map[int64]struct{}, len(userIDs))
for _, userID := range userIDs {
if userID <= 0 {
continue
}
if _, exists := seen[userID]; exists {
continue
}
seen[userID] = struct{}{}
uniqueIDs = append(uniqueIDs, userID)
result[userID] = make(map[int64]float64)
}
if len(uniqueIDs) == 0 {
return result, nil
}
rows, err := r.sql.QueryContext(ctx, `
SELECT user_id, group_id, rate_multiplier
FROM user_group_rate_multipliers
WHERE user_id = ANY($1)
`, pq.Array(uniqueIDs))
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var userID int64
var groupID int64
var rate float64
if err := rows.Scan(&userID, &groupID, &rate); err != nil {
return nil, err
}
if _, ok := result[userID]; !ok {
result[userID] = make(map[int64]float64)
}
result[userID][groupID] = rate
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// GetByUserAndGroup 获取用户在特定分组的专属倍率
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
@@ -65,33 +119,43 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
// 分离需要删除和需要 upsert 的记录
var toDelete []int64
toUpsert := make(map[int64]float64)
upsertGroupIDs := make([]int64, 0, len(rates))
upsertRates := make([]float64, 0, len(rates))
for groupID, rate := range rates {
if rate == nil {
toDelete = append(toDelete, groupID)
} else {
toUpsert[groupID] = *rate
upsertGroupIDs = append(upsertGroupIDs, groupID)
upsertRates = append(upsertRates, *rate)
}
}
// 删除指定的记录
for _, groupID := range toDelete {
_, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
userID, groupID)
if err != nil {
if len(toDelete) > 0 {
if _, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
userID, pq.Array(toDelete)); err != nil {
return err
}
}
// Upsert 记录
now := time.Now()
for groupID, rate := range toUpsert {
if len(upsertGroupIDs) > 0 {
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
VALUES ($1, $2, $3, $4, $4)
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
`, userID, groupID, rate, now)
SELECT
$1::bigint,
data.group_id,
data.rate_multiplier,
$2::timestamptz,
$2::timestamptz
FROM unnest($3::bigint[], $4::double precision[]) AS data(group_id, rate_multiplier)
ON CONFLICT (user_id, group_id)
DO UPDATE SET
rate_multiplier = EXCLUDED.rate_multiplier,
updated_at = EXCLUDED.updated_at
`, userID, now, pq.Array(upsertGroupIDs), pq.Array(upsertRates))
if err != nil {
return err
}

View File

@@ -0,0 +1,186 @@
package repository
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// Redis Key 模式(使用 hash tag 确保 Redis Cluster 下同一 accountID 的 key 落入同一 slot
// 格式: umq:{accountID}:lock / umq:{accountID}:last
const (
umqKeyPrefix = "umq:"
umqLockSuffix = ":lock" // STRING (requestID), PX lockTtlMs
umqLastSuffix = ":last" // STRING (毫秒时间戳), EX 60s
)
// Lua 脚本原子获取串行锁SET NX PX + 重入安全)
var acquireLockScript = redis.NewScript(`
local cur = redis.call('GET', KEYS[1])
if cur == ARGV[1] then
redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[2]))
return 1
end
if cur ~= false then return 0 end
redis.call('SET', KEYS[1], ARGV[1], 'PX', tonumber(ARGV[2]))
return 1
`)
// Lua 脚本:原子释放锁 + 记录完成时间(使用 Redis TIME 避免时钟偏差)
var releaseLockScript = redis.NewScript(`
local cur = redis.call('GET', KEYS[1])
if cur == ARGV[1] then
redis.call('DEL', KEYS[1])
local t = redis.call('TIME')
local ms = tonumber(t[1])*1000 + math.floor(tonumber(t[2])/1000)
redis.call('SET', KEYS[2], ms, 'EX', 60)
return 1
end
return 0
`)
// Lua 脚本:原子清理孤儿锁(仅在 PTTL == -1 时删除,避免 TOCTOU 竞态误删合法锁)
var forceReleaseLockScript = redis.NewScript(`
local pttl = redis.call('PTTL', KEYS[1])
if pttl == -1 then
redis.call('DEL', KEYS[1])
return 1
end
return 0
`)
type userMsgQueueCache struct {
rdb *redis.Client
}
// NewUserMsgQueueCache 创建用户消息队列缓存
func NewUserMsgQueueCache(rdb *redis.Client) service.UserMsgQueueCache {
return &userMsgQueueCache{rdb: rdb}
}
func umqLockKey(accountID int64) string {
// 格式: umq:{123}:lock — 花括号确保 Redis Cluster hash tag 生效
return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLockSuffix
}
func umqLastKey(accountID int64) string {
// 格式: umq:{123}:last — 与 lockKey 同一 hash slot
return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLastSuffix
}
// umqScanPattern 用于 SCAN 扫描锁 key
func umqScanPattern() string {
return umqKeyPrefix + "{*}" + umqLockSuffix
}
// AcquireLock 尝试获取账号级串行锁
func (c *userMsgQueueCache) AcquireLock(ctx context.Context, accountID int64, requestID string, lockTtlMs int) (bool, error) {
key := umqLockKey(accountID)
result, err := acquireLockScript.Run(ctx, c.rdb, []string{key}, requestID, lockTtlMs).Int()
if err != nil {
return false, fmt.Errorf("umq acquire lock: %w", err)
}
return result == 1, nil
}
// ReleaseLock 释放锁并记录完成时间
func (c *userMsgQueueCache) ReleaseLock(ctx context.Context, accountID int64, requestID string) (bool, error) {
lockKey := umqLockKey(accountID)
lastKey := umqLastKey(accountID)
result, err := releaseLockScript.Run(ctx, c.rdb, []string{lockKey, lastKey}, requestID).Int()
if err != nil {
return false, fmt.Errorf("umq release lock: %w", err)
}
return result == 1, nil
}
// GetLastCompletedMs 获取上次完成时间(毫秒时间戳)
func (c *userMsgQueueCache) GetLastCompletedMs(ctx context.Context, accountID int64) (int64, error) {
key := umqLastKey(accountID)
val, err := c.rdb.Get(ctx, key).Result()
if errors.Is(err, redis.Nil) {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("umq get last completed: %w", err)
}
ms, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return 0, fmt.Errorf("umq parse last completed: %w", err)
}
return ms, nil
}
// ForceReleaseLock 原子清理孤儿锁(仅在 PTTL == -1 时删除,防止 TOCTOU 竞态误删合法锁)
func (c *userMsgQueueCache) ForceReleaseLock(ctx context.Context, accountID int64) error {
key := umqLockKey(accountID)
_, err := forceReleaseLockScript.Run(ctx, c.rdb, []string{key}).Result()
if err != nil && !errors.Is(err, redis.Nil) {
return fmt.Errorf("umq force release lock: %w", err)
}
return nil
}
// ScanLockKeys 扫描所有锁 key仅返回 PTTL == -1无过期时间的孤儿锁 accountID 列表
// 正常的锁都有 PX 过期时间PTTL == -1 表示异常状态(如 Redis 故障恢复后丢失 TTL
func (c *userMsgQueueCache) ScanLockKeys(ctx context.Context, maxCount int) ([]int64, error) {
var accountIDs []int64
var cursor uint64
pattern := umqScanPattern()
for {
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, 100).Result()
if err != nil {
return nil, fmt.Errorf("umq scan lock keys: %w", err)
}
for _, key := range keys {
// 检查 PTTL只清理 PTTL == -1无过期时间的异常锁
pttl, err := c.rdb.PTTL(ctx, key).Result()
if err != nil {
continue
}
// PTTL 返回值:-2 = key 不存在,-1 = 无过期时间,>0 = 剩余毫秒
// go-redis 对哨兵值 -1/-2 不乘精度系数,直接返回 time.Duration(-1)/-2
// 只删除 -1无过期时间的异常锁跳过正常持有的锁
if pttl != time.Duration(-1) {
continue
}
// 从 key 中提取 accountID: umq:{123}:lock → 提取 {} 内的数字
openBrace := strings.IndexByte(key, '{')
closeBrace := strings.IndexByte(key, '}')
if openBrace < 0 || closeBrace <= openBrace+1 {
continue
}
idStr := key[openBrace+1 : closeBrace]
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
continue
}
accountIDs = append(accountIDs, id)
if len(accountIDs) >= maxCount {
return accountIDs, nil
}
}
cursor = nextCursor
if cursor == 0 {
break
}
}
return accountIDs, nil
}
// GetCurrentTimeMs 通过 Redis TIME 命令获取当前服务器时间(毫秒),确保与锁记录的时间源一致
func (c *userMsgQueueCache) GetCurrentTimeMs(ctx context.Context) (int64, error) {
t, err := c.rdb.Time(ctx).Result()
if err != nil {
return 0, fmt.Errorf("umq get redis time: %w", err)
}
return t.UnixMilli(), nil
}

View File

@@ -61,6 +61,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
@@ -143,6 +144,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
@@ -240,21 +243,24 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
userMap[u.ID] = &outUsers[len(outUsers)-1]
}
// Batch load active subscriptions with groups to avoid N+1.
subs, err := r.client.UserSubscription.Query().
Where(
usersubscription.UserIDIn(userIDs...),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
).
WithGroup().
All(ctx)
if err != nil {
return nil, nil, err
}
shouldLoadSubscriptions := filters.IncludeSubscriptions == nil || *filters.IncludeSubscriptions
if shouldLoadSubscriptions {
// Batch load active subscriptions with groups to avoid N+1.
subs, err := r.client.UserSubscription.Query().
Where(
usersubscription.UserIDIn(userIDs...),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
).
WithGroup().
All(ctx)
if err != nil {
return nil, nil, err
}
for i := range subs {
if u, ok := userMap[subs[i].UserID]; ok {
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
for i := range subs {
if u, ok := userMap[subs[i].UserID]; ok {
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
}
}
}
@@ -363,10 +369,79 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
return nil
}
// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。
func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
if deltaBytes <= 0 {
user, err := r.GetByID(ctx, userID)
if err != nil {
return 0, err
}
return user.SoraStorageUsedBytes, nil
}
var newUsed int64
err := scanSingleRow(ctx, r.sql, `
UPDATE users
SET sora_storage_used_bytes = sora_storage_used_bytes + $2
WHERE id = $1
AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
RETURNING sora_storage_used_bytes
`, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
if err == nil {
return newUsed, nil
}
if errors.Is(err, sql.ErrNoRows) {
// 区分用户不存在和配额冲突
exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
if existsErr != nil {
return 0, existsErr
}
if !exists {
return 0, service.ErrUserNotFound
}
return 0, service.ErrSoraStorageQuotaExceeded
}
return 0, err
}
// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。
func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
if deltaBytes <= 0 {
user, err := r.GetByID(ctx, userID)
if err != nil {
return 0, err
}
return user.SoraStorageUsedBytes, nil
}
var newUsed int64
err := scanSingleRow(ctx, r.sql, `
UPDATE users
SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
WHERE id = $1
RETURNING sora_storage_used_bytes
`, []any{userID, deltaBytes}, &newUsed)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return 0, service.ErrUserNotFound
}
return 0, err
}
return newUsed, nil
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
}
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
client := clientFromContext(ctx, r.client)
return client.UserAllowedGroup.Create().
SetUserID(userID).
SetGroupID(groupID).
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx)
}
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
// 仅操作 user_allowed_groups 联接表legacy users.allowed_groups 列已弃用。
affected, err := r.client.UserAllowedGroup.Delete().

View File

@@ -28,13 +28,13 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
// ProvideGitHubReleaseClient 创建 GitHub Release 客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub
func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient {
return NewGitHubReleaseClient(cfg.Update.ProxyURL)
return NewGitHubReleaseClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError)
}
// ProvidePricingRemoteClient 创建定价数据远程客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据
func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
return NewPricingRemoteClient(cfg.Update.ProxyURL)
return NewPricingRemoteClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError)
}
// ProvideSessionLimitCache 创建会话限制缓存
@@ -53,12 +53,14 @@ var ProviderSet = wire.NewSet(
NewAPIKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewSoraAccountRepository, // Sora 账号扩展表仓储
NewProxyRepository,
NewRedeemCodeRepository,
NewPromoCodeRepository,
NewAnnouncementRepository,
NewAnnouncementReadRepository,
NewUsageLogRepository,
NewIdempotencyRepository,
NewUsageCleanupRepository,
NewDashboardAggregationRepository,
NewSettingRepository,
@@ -77,6 +79,8 @@ var ProviderSet = wire.NewSet(
NewTimeoutCounterCache,
ProvideConcurrencyCache,
ProvideSessionLimitCache,
NewRPMCache,
NewUserMsgQueueCache,
NewDashboardCache,
NewEmailCache,
NewIdentityCache,
@@ -104,6 +108,7 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthClient,
NewGeminiOAuthClient,
NewGeminiCliCodeAssistClient,
NewGeminiDriveClient,
ProvideEnt,
ProvideSQLDB,