fix: round-2 audit fixes — security, code quality, and UI improvements

Security (HIGH):
- Normalize all Redis cache keys to lowercase (verifyCode, passwordReset)
- Fix verify code TTL renewal on failed attempts: use remaining TTL via
  ExpiresAt field instead of resetting to full 15-minute window
- Add 3 missing fields to diffSettings audit log (promo_code, invitation_code,
  custom_endpoints)

Code quality (MEDIUM):
- Extract filterVerifiedEmails shared helper (balance_notify_service.go)
- Add Pricing array non-empty validation for channel pricing rules
- Add platform token semantics comment in gateway_service.go
- Complete validatePlanPatch test coverage (+10 test cases)
- Replace string types with QuotaThresholdType/QuotaResetMode across frontend
- Remove duplicate getPlatformTextColor/getRateBadgeClass in ChannelsView
- Return EMAIL_NOT_FOUND error on RemoveNotifyEmail miss

UI improvements:
- Reorder cost tooltip: user billing above separator, account billing below
- Add NaN guard to accountBilled function
- Move timezone selector inline into reset-mode row (no longer standalone)
This commit is contained in:
erio
2026-04-14 00:26:20 +08:00
parent 74f8a30f86
commit a9880ee7b9
15 changed files with 605 additions and 291 deletions

View File

@@ -357,6 +357,11 @@ func (h *ChannelHandler) Create(c *gin.Context) {
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return
}
if len(r.Pricing) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i
statsRules = append(statsRules, rule)
@@ -420,6 +425,11 @@ func (h *ChannelHandler) Update(c *gin.Context) {
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return
}
if len(r.Pricing) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i
statsRules = append(statsRules, rule)

View File

@@ -1138,6 +1138,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
changed = append(changed, "registration_email_suffix_whitelist")
}
if before.PromoCodeEnabled != after.PromoCodeEnabled {
changed = append(changed, "promo_code_enabled")
}
if before.InvitationCodeEnabled != after.InvitationCodeEnabled {
changed = append(changed, "invitation_code_enabled")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
@@ -1348,6 +1354,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.CustomMenuItems != after.CustomMenuItems {
changed = append(changed, "custom_menu_items")
}
if before.CustomEndpoints != after.CustomEndpoints {
changed = append(changed, "custom_endpoints")
}
if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
changed = append(changed, "enable_fingerprint_unification")
}

View File

@@ -20,8 +20,9 @@ const (
)
// verifyCodeKey generates the Redis key for email verification code.
// Email is lowercased for case-insensitive consistency.
func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
return verifyCodeKeyPrefix + strings.ToLower(email)
}
// notifyVerifyKey generates the Redis key for notify email verification code.
@@ -33,12 +34,12 @@ func notifyVerifyKey(email string) string {
// passwordResetKey generates the Redis key for password reset token.
func passwordResetKey(email string) string {
return passwordResetKeyPrefix + email
return passwordResetKeyPrefix + strings.ToLower(email)
}
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
func passwordResetSentAtKey(email string) string {
return passwordResetSentAtKeyPrefix + email
return passwordResetSentAtKeyPrefix + strings.ToLower(email)
}
type emailCache struct {

View File

@@ -283,6 +283,20 @@ func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context)
return nil
}
return filterVerifiedEmails(entries)
}
// getSiteName reads site name from settings with fallback.
func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
if err != nil || name == "" {
return defaultSiteName
}
return name
}
// filterVerifiedEmails returns deduplicated, non-disabled, verified emails.
func filterVerifiedEmails(entries []NotifyEmailEntry) []string {
var recipients []string
seen := make(map[string]bool)
for _, entry := range entries {
@@ -303,38 +317,10 @@ func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context)
return recipients
}
// getSiteName reads site name from settings with fallback.
func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
if err != nil || name == "" {
return defaultSiteName
}
return name
}
// collectBalanceNotifyRecipients returns verified, non-disabled email recipients.
// Only emails with verified=true and disabled=false are included.
func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string {
var recipients []string
seen := make(map[string]bool)
for _, entry := range user.BalanceNotifyExtraEmails {
if entry.Disabled || !entry.Verified {
continue
}
email := strings.TrimSpace(entry.Email)
if email == "" {
continue
}
lower := strings.ToLower(email)
if seen[lower] {
continue
}
seen[lower] = true
recipients = append(recipients, email)
}
return recipients
return filterVerifiedEmails(user.BalanceNotifyExtraEmails)
}
// sendEmails sends an email to all recipients with shared timeout and error logging.

View File

@@ -55,6 +55,7 @@ type VerificationCodeData struct {
Code string
Attempts int
CreatedAt time.Time
ExpiresAt time.Time // absolute expiry; used to preserve remaining TTL when updating attempts
}
// PasswordResetTokenData represents password reset token data
@@ -263,6 +264,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
Code: code,
Attempts: 0,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(verifyCodeTTL),
}
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err)
@@ -295,7 +297,11 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
remaining := time.Until(data.ExpiresAt)
if remaining <= 0 {
return ErrInvalidVerifyCode
}
if err := s.cache.SetVerificationCode(ctx, email, data, remaining); err != nil {
slog.Error("failed to update verification attempt count", "email", email, "error", err)
}
if data.Attempts >= maxVerifyCodeAttempts {

View File

@@ -1194,12 +1194,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err != nil {
return nil, err
}
return s.hydrateSelectedAccount(ctx, account)
}
// antigravity 分组、强制平台模式或无分组使用单平台选择
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err != nil {
return nil, err
}
return s.hydrateSelectedAccount(ctx, account)
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
@@ -1275,11 +1283,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
localExcluded[account.ID] = struct{}{} // 排除此账号
continue // 重新选择
}
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
}
// 对于等待计划的情况,也需要先检查会话限制
@@ -1291,26 +1295,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
})
}
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
})
}
}
@@ -1433,36 +1431,39 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
// 粘性账号在路由列表中,优先使用
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
if s.isAccountSchedulableForSelection(stickyAccount) &&
var stickyCacheMissReason string
gatePass := s.isAccountSchedulableForSelection(stickyAccount) &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
s.isAccountSchedulableForQuota(stickyAccount) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true)
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true)
if rpmPass { // 粘性会话窗口费用+RPM 检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
result.ReleaseFunc() // 释放槽位
stickyCacheMissReason = "session_limit"
// 继续到负载感知选择
} else {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
}
return &AccountSelectionResult{
Account: stickyAccount,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil)
}
}
if stickyCacheMissReason == "" {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
stickyCacheMissReason = "session_limit"
// 会话限制已满,继续到负载感知选择
} else {
return &AccountSelectionResult{
@@ -1475,11 +1476,31 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
},
}, nil
}
} else {
stickyCacheMissReason = "wait_queue_full"
}
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
} else if !gatePass {
stickyCacheMissReason = "gate_check"
} else {
stickyCacheMissReason = "rpm_red"
}
// 记录粘性缓存未命中的结构化日志
if stickyCacheMissReason != "" {
baseRPM := stickyAccount.GetBaseRPM()
var currentRPM int
if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok {
currentRPM = count
}
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d",
stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM)
}
} else {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0",
stickyAccountID, shortSessionHash(sessionHash))
}
}
}
@@ -1544,11 +1565,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil)
}
}
@@ -1561,15 +1578,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
return &AccountSelectionResult{
Account: item.account,
WaitPlan: &AccountWaitPlan{
return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{
AccountID: item.account.ID,
MaxConcurrency: item.account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
})
}
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
}
@@ -1603,11 +1617,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
if s.cache != nil {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
}
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
}
}
@@ -1617,15 +1630,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
} else {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
})
}
}
}
@@ -1684,7 +1694,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil {
return nil, legacyErr
} else if ok {
return result, nil
}
} else {
@@ -1723,11 +1735,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil)
}
}
@@ -1750,20 +1758,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
continue // 会话限制已满,尝试下一个账号
}
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
})
}
return nil, ErrNoAvailableAccounts
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
@@ -1778,15 +1783,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, true
selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil)
if err != nil {
return nil, false, err
}
return selection, true, nil
}
}
return nil, false
return nil, false, nil
}
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
@@ -2401,6 +2406,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
return s.accountRepo.GetByID(ctx, accountID)
}
func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) {
if account == nil || s.schedulerSnapshot == nil {
return account, nil
}
hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID)
if err != nil {
return nil, err
}
if hydrated == nil {
return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID)
}
return hydrated, nil
}
func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) {
hydrated, err := s.hydrateSelectedAccount(ctx, account)
if err != nil {
return nil, err
}
return &AccountSelectionResult{
Account: hydrated,
Acquired: acquired,
ReleaseFunc: release,
WaitPlan: waitPlan,
}, nil
}
// filterByMinPriority 过滤出优先级最小的账号集合
func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 {
@@ -2676,6 +2708,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
// require_privacy_set: 获取分组信息
var schedGroup *Group
if groupID != nil && s.groupRepo != nil {
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
}
var accounts []Account
accountsLoaded := false
@@ -2747,6 +2785,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
@@ -2852,6 +2896,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
@@ -2918,6 +2968,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
preferOAuth := nativePlatform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
// require_privacy_set: 获取分组信息
var schedGroup *Group
if groupID != nil && s.groupRepo != nil {
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
}
var accounts []Account
accountsLoaded := false
@@ -2985,6 +3041,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
// 过滤原生平台直接通过antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
@@ -3078,6 +3140,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
@@ -3090,6 +3153,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
// 过滤原生平台直接通过antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
@@ -3257,8 +3326,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "excluded"}
}
if !s.isAccountSchedulableForSelection(acc) {
detail := "generic_unschedulable"
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
}
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
return selectionFailureDiagnosis{
@@ -3282,7 +3350,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "eligible"}
}
// GetAccessToken 获取账号凭证
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
if acc == nil {
return true
@@ -3653,6 +3720,86 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
return result
}
// rewriteSystemForNonClaudeCode 将非 Claude Code 客户端的 system prompt 迁移至 messages
// system 字段仅保留 Claude Code 标识提示词。
// Anthropic 基于 system 参数内容检测第三方应用,仅前置追加 Claude Code 提示词
// 无法通过检测,因为后续内容仍为非 Claude Code 格式。
// 策略:将原始 system prompt 提取并注入为 user/assistant 消息对system 仅保留 Claude Code 标识。
func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
system = normalizeSystemParam(system)
// 1. 提取原始 system prompt 文本
var originalSystemText string
switch v := system.(type) {
case string:
originalSystemText = strings.TrimSpace(v)
case []any:
var parts []string
for _, item := range v {
if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" {
parts = append(parts, text)
}
}
}
originalSystemText = strings.Join(parts, "\n\n")
}
// 2. 将 system 替换为 Claude Code 标准提示词array 格式,与真实 Claude Code 一致)
// 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。
// 使用 string 格式会被 Anthropic 检测为第三方应用。
claudeCodeSystemBlock := []map[string]any{
{
"type": "text",
"text": claudeCodeSystemPrompt,
"cache_control": map[string]string{"type": "ephemeral"},
},
}
out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock)
if !ok {
logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt")
return body
}
// 3. 将原始 system prompt 作为 user/assistant 消息对注入到 messages 开头
// 模型仍通过 messages 接收完整指令,保留客户端功能
ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt)
if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) {
instrMsg, err1 := json.Marshal(map[string]any{
"role": "user",
"content": []map[string]any{
{"type": "text", "text": "[System Instructions]\n" + originalSystemText},
},
})
ackMsg, err2 := json.Marshal(map[string]any{
"role": "assistant",
"content": []map[string]any{
{"type": "text", "text": "Understood. I will follow these instructions."},
},
})
if err1 != nil || err2 != nil {
logger.LegacyPrintf("service.gateway", "Warning: failed to marshal system-to-messages injection")
return out
}
// 重建 messages 数组:[instruction, ack, ...originalMessages]
items := [][]byte{instrMsg, ackMsg}
messagesResult := gjson.GetBytes(out, "messages")
if messagesResult.IsArray() {
messagesResult.ForEach(func(_, msg gjson.Result) bool {
items = append(items, []byte(msg.Raw))
return true
})
}
if next, setOk := setJSONRawBytes(out, "messages", buildJSONArrayRaw(items)); setOk {
out = next
}
}
return out
}
type cacheControlPath struct {
path string
log string
@@ -3819,7 +3966,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
if account.Platform == PlatformAnthropic && c != nil {
policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account)
policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account, parsed.Model)
if policy.blockErr != nil {
return nil, policy.blockErr
}
@@ -3849,19 +3996,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages
// 条件1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
systemRewritten := false
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) {
body = injectClaudeCodePrompt(body, parsed.System)
body = rewriteSystemForNonClaudeCode(body, parsed.System)
systemRewritten = true
}
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
// system 被重写时保留 CC prompt 的 cache_control: ephemeral匹配真实 Claude Code 行为);
// 未重写时haiku / 已含 CC 前缀)剥离客户端 cache_control与原有行为一致。
// 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
if s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil && fp != nil {
// metadata 透传开启时跳过 metadata 注入
_, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx)
_, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx)
if !mimicMPT {
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
normalizeOpts.injectMetadata = true
@@ -5407,9 +5559,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// OAuth账号应用统一指纹和metadata重写受设置开关控制
var fingerprint *Fingerprint
enableFP, enableMPT := true, false
enableFP, enableMPT, enableCCH := true, false, false
if s.settingService != nil {
enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
}
if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹包含随机生成的ClientID
@@ -5436,6 +5588,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// 同步 billing header cc_version 与实际发送的 User-Agent 版本
if fingerprint != nil {
body = syncBillingHeaderVersion(body, fingerprint.UserAgent)
}
// CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后)
if enableCCH {
body = signBillingHeaderCCH(body)
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -5476,9 +5637,8 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account)
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID)
effectiveDropSet := mergeDropSets(policyFilterSet)
effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode)
// 处理 anthropic-beta headerOAuth 账号需要包含 oauth beta
if tokenType == "oauth" {
@@ -5489,11 +5649,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
applyClaudeCodeMimicHeaders(req, reqStream)
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
// Match real Claude CLI traffic (per mitmproxy reports):
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
// Claude Code OAuth credentials are scoped to Claude Code.
// Non-haiku models MUST include claude-code beta for Anthropic to recognize
// this as a legitimate Claude Code request; without it, the request is
// rejected as third-party ("out of extra usage").
// Haiku models are exempt from third-party detection and don't need it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet))
if !strings.Contains(strings.ToLower(modelID), "haiku") {
requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking}
}
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
} else {
// Claude Code 客户端:尽量透传原始 header仅补齐 oauth beta
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
@@ -5716,7 +5881,7 @@ type betaPolicyResult struct {
}
// evaluateBetaPolicy loads settings once and evaluates all rules against the given request.
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult {
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account, model string) betaPolicyResult {
if s.settingService == nil {
return betaPolicyResult{}
}
@@ -5731,10 +5896,11 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue
}
switch rule.Action {
effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
switch effectiveAction {
case BetaPolicyActionBlock:
if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) {
msg := rule.ErrorMessage
msg := effectiveErrMsg
if msg == "" {
msg = "beta feature " + rule.BetaToken + " is not allowed"
}
@@ -5776,7 +5942,7 @@ const betaPolicyFilterSetKey = "betaPolicyFilterSet"
// In the /v1/messages path, Forward() evaluates the policy first and caches the result;
// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this
// evaluates on demand (one DB call).
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} {
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account, model string) map[string]struct{} {
if c != nil {
if v, ok := c.Get(betaPolicyFilterSetKey); ok {
if fs, ok := v.(map[string]struct{}); ok {
@@ -5784,7 +5950,7 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont
}
}
}
return s.evaluateBetaPolicy(ctx, "", account).filterSet
return s.evaluateBetaPolicy(ctx, "", account, model).filterSet
}
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
@@ -5803,6 +5969,33 @@ func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
}
}
// matchModelWhitelist checks if a model matches any pattern in the whitelist.
// Reuses matchModelPattern from group.go which supports exact and wildcard prefix matching.
func matchModelWhitelist(model string, whitelist []string) bool {
for _, pattern := range whitelist {
if matchModelPattern(pattern, model) {
return true
}
}
return false
}
// resolveRuleAction determines the effective action and error message for a rule given the request model.
// When ModelWhitelist is empty, the rule's primary Action/ErrorMessage applies unconditionally.
// When non-empty, Action applies to matching models; FallbackAction/FallbackErrorMessage applies to others.
func resolveRuleAction(rule BetaPolicyRule, model string) (action, errorMessage string) {
if len(rule.ModelWhitelist) == 0 {
return rule.Action, rule.ErrorMessage
}
if matchModelWhitelist(model, rule.ModelWhitelist) {
return rule.Action, rule.ErrorMessage
}
if rule.FallbackAction != "" {
return rule.FallbackAction, rule.FallbackErrorMessage
}
return BetaPolicyActionPass, "" // default fallback: pass (fail-open)
}
// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
func droppedBetaSet(extra ...string) map[string]struct{} {
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra))
@@ -5849,7 +6042,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
modelID string,
) ([]string, error) {
// 1. 对原始 header 中的 beta token 做 block 检查(快速失败)
policy := s.evaluateBetaPolicy(ctx, betaHeader, account)
policy := s.evaluateBetaPolicy(ctx, betaHeader, account, modelID)
if policy.blockErr != nil {
return nil, policy.blockErr
}
@@ -5861,7 +6054,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
// 例如:管理员 block 了 interleaved-thinking客户端不在 header 中带该 token
// 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 →
// 如果不做此检查block 规则会被绕过。
if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil {
if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account, modelID); blockErr != nil {
return nil, blockErr
}
@@ -5870,7 +6063,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
// checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。
// 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。
func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError {
func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account, model string) *BetaBlockedError {
if s.settingService == nil || len(tokens) == 0 {
return nil
}
@@ -5882,14 +6075,15 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke
isBedrock := account.IsBedrock()
tokenSet := buildBetaTokenSet(tokens)
for _, rule := range settings.Rules {
if rule.Action != BetaPolicyActionBlock {
effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
if effectiveAction != BetaPolicyActionBlock {
continue
}
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue
}
if _, present := tokenSet[rule.BetaToken]; present {
msg := rule.ErrorMessage
msg := effectiveErrMsg
if msg == "" {
msg = "beta feature " + rule.BetaToken + " is not allowed"
}
@@ -7146,49 +7340,41 @@ func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
}
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
// - 订阅/余额扣费
// - API Key 配额更新
// - API Key 限速用量更新
// - 账号配额用量更新账号口径TotalCost × 账号计费倍率)
// postUsageBilling is the legacy fallback billing path used when the unified
// billing repo is unavailable (nil). Production uses applyUsageBilling → repo.Apply
// for atomic billing. This path only runs in tests or degraded mode.
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
billingCtx, cancel := detachedBillingContext(ctx)
defer cancel()
cost := p.Cost
// 1. 订阅 / 余额扣费
if p.IsSubscriptionBill {
if cost.TotalCost > 0 {
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
}
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
}
} else {
if cost.ActualCost > 0 {
if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil {
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
}
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
}
}
// 2. API Key 配额
if p.shouldDeductAPIKeyQuota() {
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
}
}
// 3. API Key 限速用量
if p.shouldUpdateRateLimits() {
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
}
}
// 4. 账号配额用量账号口径TotalCost × 账号计费倍率)
if p.shouldUpdateAccountQuota() {
accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
@@ -7196,7 +7382,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
}
}
finalizePostUsageBilling(p, deps)
// NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing
// cache updates. The legacy path does DB writes directly; the finalize path
// does cache queue + notifications. Notifications are dispatched separately
// by the caller after recording the usage log.
}
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
@@ -7250,9 +7439,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
cmd.CacheReadTokens = usageLog.CacheReadTokens
cmd.ImageCount = usageLog.ImageCount
if usageLog.MediaType != nil {
cmd.MediaType = *usageLog.MediaType
}
if usageLog.ServiceTier != nil {
cmd.ServiceTier = *usageLog.ServiceTier
}
@@ -7315,11 +7501,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog
}
}
finalizePostUsageBilling(p, deps)
finalizePostUsageBilling(p, deps, result)
return true, nil
}
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
if p == nil || p.Cost == nil || deps == nil {
return
}
@@ -7338,22 +7524,82 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
// Balance low notification — use real-time balance from billing cache (not stale snapshot)
if !p.IsSubscriptionBill && p.Cost.ActualCost > 0 && p.User != nil && deps.balanceNotifyService != nil {
oldBalance := p.User.Balance // fallback to snapshot
if deps.billingCacheService != nil {
if realBalance, err := deps.billingCacheService.GetUserBalance(context.Background(), p.User.ID); err == nil {
oldBalance = realBalance + p.Cost.ActualCost // DB already deducted, reconstruct pre-deduction balance
// Notification checks run async — all parameters are already captured,
// no dependency on the request context or upstream connection.
go notifyBalanceLow(p, deps, result)
go notifyAccountQuota(p, deps, result)
}
// notifyBalanceLow sends balance low notification after deduction.
// When result.NewBalance is available (from DB transaction RETURNING), it is used directly
// to reconstruct oldBalance, avoiding stale Redis reads and concurrent-deduction races.
func notifyBalanceLow(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in notifyBalanceLow", "recover", r)
}
}
deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost)
}()
if p.IsSubscriptionBill || p.Cost.ActualCost <= 0 || p.User == nil || deps.balanceNotifyService == nil {
slog.Debug("notifyBalanceLow: skipped",
"is_subscription", p.IsSubscriptionBill,
"actual_cost", p.Cost.ActualCost,
"user_nil", p.User == nil,
"service_nil", deps.balanceNotifyService == nil,
)
return
}
// Account quota notification (use same cost formula as postUsageBilling)
if p.Cost.TotalCost > 0 && p.Account != nil && p.Account.IsAPIKeyOrBedrock() && deps.balanceNotifyService != nil {
accountCost := p.Cost.TotalCost * p.AccountRateMultiplier
deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost)
oldBalance := resolveOldBalance(p, result)
slog.Debug("notifyBalanceLow: calling CheckBalanceAfterDeduction",
"user_id", p.User.ID,
"old_balance", oldBalance,
"cost", p.Cost.ActualCost,
"notify_enabled", p.User.BalanceNotifyEnabled,
"threshold", p.User.BalanceNotifyThreshold,
"result_has_new_balance", result != nil && result.NewBalance != nil,
)
deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost)
}
// resolveOldBalance returns the pre-deduction balance.
// Prefers the DB transaction result (newBalance + cost) over snapshot.
func resolveOldBalance(p *postUsageBillingParams, result *UsageBillingApplyResult) float64 {
if result != nil && result.NewBalance != nil {
return *result.NewBalance + p.Cost.ActualCost
}
// Legacy fallback: snapshot balance from request context
return p.User.Balance
}
// notifyAccountQuota sends account quota threshold notification after increment.
// When result.QuotaState is available (from DB transaction RETURNING), it is passed directly
// to avoid a separate DB read that may see stale or concurrently-modified data.
func notifyAccountQuota(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in notifyAccountQuota", "recover", r)
}
}()
if p.Cost.TotalCost <= 0 || p.Account == nil || !p.Account.IsAPIKeyOrBedrock() || deps.balanceNotifyService == nil {
slog.Debug("notifyAccountQuota: skipped",
"total_cost", p.Cost.TotalCost,
"account_nil", p.Account == nil,
"is_apikey_or_bedrock", p.Account != nil && p.Account.IsAPIKeyOrBedrock(),
"service_nil", deps.balanceNotifyService == nil,
)
return
}
accountCost := p.Cost.TotalCost * p.AccountRateMultiplier
var quotaState *AccountQuotaState
if result != nil {
quotaState = result.QuotaState
}
slog.Debug("notifyAccountQuota: calling CheckAccountQuotaAfterIncrement",
"account_id", p.Account.ID,
"account_cost", accountCost,
"has_quota_state", quotaState != nil,
)
deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost, quotaState)
}
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
@@ -7422,11 +7668,11 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
type recordUsageOpts struct {
// ParsedRequest可选仅 Claude 路径传入)
// Claude Max 策略所需的 ParsedRequest可选仅 Claude 路径传入)
ParsedRequest *ParsedRequest
// EnableClaudePath 启用 Claude 路径特有逻辑:
// - MediaType 字段写入使用日志
// - Claude Max 缓存计费策略
EnableClaudePath bool
// 长上下文计费(仅 Gemini 路径需要)
@@ -7451,7 +7697,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
APIKeyService: input.APIKeyService,
ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{
ParsedRequest: input.ParsedRequest,
EnableClaudePath: true,
})
}
@@ -7517,6 +7762,7 @@ type recordUsageCoreInput struct {
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为:
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
result := input.Result
@@ -7583,13 +7829,10 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if apiKey.GroupID != nil {
upstreamModel := result.UpstreamModel
if upstreamModel == "" {
upstreamModel = result.Model
}
usageLog.AccountStatsCost = resolveAccountStatsCost(
ctx, s.channelService, s.billingService,
account.ID, *apiKey.GroupID, upstreamModel,
applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService,
account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model,
// Anthropic's input_tokens excludes cache_read and cache_creation (billed separately);
// OpenAI gateway uses actualInputTokens which also excludes cache_read for the same reason.
UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
@@ -7597,7 +7840,6 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
CacheReadTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
},
1, // requestCount
cost.TotalCost,
)
}
@@ -7796,13 +8038,12 @@ func (s *GatewayService) buildRecordUsageLog(
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
BillingMode: resolveBillingMode(opts, result, cost),
BillingMode: resolveBillingMode(result, cost),
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
MediaType: resolveMediaType(opts, result),
CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
@@ -7826,7 +8067,7 @@ func (s *GatewayService) buildRecordUsageLog(
}
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
var mode string
switch {
case cost != nil && cost.BillingMode != "":
@@ -7839,10 +8080,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
return &mode
}
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
return nil
}
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
if subscription != nil {
return &subscription.ID
@@ -8349,9 +8586,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用统一指纹和重写 userID受设置开关控制
// 如果启用了会话ID伪装会在重写后替换 session 部分为固定值
ctEnableFP, ctEnableMPT := true, false
ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false
if s.settingService != nil {
ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
}
var ctFingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil {
@@ -8369,6 +8606,14 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
// 同步 billing header cc_version 与实际发送的 User-Agent 版本
if ctFingerprint != nil && ctEnableFP {
body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent)
}
if ctEnableCCH {
body = signBillingHeaderCCH(body)
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -8409,7 +8654,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account))
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID))
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {

View File

@@ -128,3 +128,66 @@ func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil})
require.NoError(t, err)
}
// --- validatePlanPatch: other fields ---
func ptrStr(s string) *string { return &s }
func ptrInt(i int) *int { return &i }
func ptrInt64(i int64) *int64 { return &i }
func ptrFloat(f float64) *float64 { return &f }
func TestValidatePlanPatch_EmptyName(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("")})
require.Error(t, err)
require.Contains(t, err.Error(), "plan name")
}
func TestValidatePlanPatch_ValidName(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("Basic")})
require.NoError(t, err)
}
func TestValidatePlanPatch_ZeroGroupID(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{GroupID: ptrInt64(0)})
require.Error(t, err)
require.Contains(t, err.Error(), "group")
}
func TestValidatePlanPatch_NegativePrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(-1)})
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanPatch_ZeroPrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(0)})
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanPatch_ValidPrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(9.99)})
require.NoError(t, err)
}
func TestValidatePlanPatch_ZeroValidityDays(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{ValidityDays: ptrInt(0)})
require.Error(t, err)
require.Contains(t, err.Error(), "validity days")
}
func TestValidatePlanPatch_EmptyValidityUnit(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("")})
require.Error(t, err)
require.Contains(t, err.Error(), "validity unit")
}
func TestValidatePlanPatch_ValidValidityUnit(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("days")})
require.NoError(t, err)
}
func TestValidatePlanPatch_AllNil(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{})
require.NoError(t, err)
}

View File

@@ -330,6 +330,7 @@ func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code str
Code: code,
Attempts: 0,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(verifyCodeTTL),
}
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err)
@@ -370,7 +371,11 @@ func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string)
}
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
remaining := time.Until(data.ExpiresAt)
if remaining <= 0 {
return ErrInvalidVerifyCode
}
if err := cache.SetNotifyVerifyCode(ctx, email, data, remaining); err != nil {
slog.Error("failed to update notify verify code attempts", "email", email, "error", err)
}
if data.Attempts >= maxVerifyCodeAttempts {
@@ -418,11 +423,17 @@ func (s *UserService) RemoveNotifyEmail(ctx context.Context, userID int64, email
}
filtered := make([]NotifyEmailEntry, 0, len(user.BalanceNotifyExtraEmails))
found := false
for _, e := range user.BalanceNotifyExtraEmails {
if !strings.EqualFold(e.Email, email) {
if strings.EqualFold(e.Email, email) {
found = true
} else {
filtered = append(filtered, e)
}
}
if !found {
return infraerrors.BadRequest("EMAIL_NOT_FOUND", "notification email not found")
}
user.BalanceNotifyExtraEmails = filtered
return s.userRepo.Update(ctx, user)
}

View File

@@ -1,6 +1,7 @@
<script setup lang="ts">
import { useI18n } from 'vue-i18n'
import QuotaNotifyToggle from './QuotaNotifyToggle.vue'
import type { QuotaThresholdType, QuotaResetMode } from '@/constants/account'
const { t } = useI18n()
@@ -11,9 +12,9 @@ const props = defineProps<{
quotaNotifyGlobalEnabled: boolean
notifyEnabled: boolean | null
notifyThreshold: number | null
notifyThresholdType: string | null
notifyThresholdType: QuotaThresholdType | null
// Reset mode (only for daily/weekly, null for total)
resetMode: 'rolling' | 'fixed' | null
resetMode: QuotaResetMode | null
resetHour: number | null
resetDay: number | null // weekly only
resetTimezone: string | null
@@ -22,14 +23,15 @@ const props = defineProps<{
// Shared options passed from parent
hourOptions: number[]
dayOptions: { value: number; key: string }[]
timezoneOptions?: string[]
}>()
const emit = defineEmits<{
'update:limit': [value: number | null]
'update:notifyEnabled': [value: boolean | null]
'update:notifyThreshold': [value: number | null]
'update:notifyThresholdType': [value: string | null]
'update:resetMode': [value: 'rolling' | 'fixed' | null]
'update:notifyThresholdType': [value: QuotaThresholdType | null]
'update:resetMode': [value: QuotaResetMode | null]
'update:resetHour': [value: number | null]
'update:resetDay': [value: number | null]
'update:resetTimezone': [value: string | null]
@@ -43,7 +45,7 @@ const onLimitInput = (e: Event) => {
}
const onModeChange = (e: Event) => {
const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed'
const val = (e.target as HTMLSelectElement).value as QuotaResetMode
emit('update:resetMode', val)
if (val === 'fixed') {
if (props.resetHour == null) emit('update:resetHour', 0)
@@ -51,6 +53,17 @@ const onModeChange = (e: Event) => {
if (!props.resetTimezone) emit('update:resetTimezone', 'UTC')
}
}
function getTimezoneOffsetLabel(tz: string): string {
try {
const dtf = new Intl.DateTimeFormat('en-US', { timeZone: tz, timeZoneName: 'shortOffset' })
const parts = dtf.formatToParts(new Date())
const tzPart = parts.find(p => p.type === 'timeZoneName')
return tzPart ? (tzPart.value === 'GMT' ? 'GMT+0' : tzPart.value) : ''
} catch {
return ''
}
}
</script>
<template>
@@ -95,6 +108,11 @@ const onModeChange = (e: Event) => {
<select :value="resetHour ?? 0" @change="emit('update:resetHour', Number(($event.target as HTMLSelectElement).value))" class="input py-1 text-xs w-24">
<option v-for="h in hourOptions" :key="h" :value="h">{{ String(h).padStart(2, '0') }}:00</option>
</select>
<template v-if="timezoneOptions && timezoneOptions.length > 0">
<select :value="resetTimezone || 'UTC'" @change="emit('update:resetTimezone', ($event.target as HTMLSelectElement).value)" class="input py-1 text-xs w-auto">
<option v-for="tz in timezoneOptions" :key="tz" :value="tz">{{ tz }} ({{ getTimezoneOffsetLabel(tz) }})</option>
</select>
</template>
</template>
<span class="text-[11px] text-gray-500 dark:text-gray-400">
<template v-if="resetMode === 'fixed'">{{ hintFixed }}</template>

View File

@@ -2,6 +2,7 @@
import { ref, watch, computed } from 'vue'
import { useI18n } from 'vue-i18n'
import QuotaDimensionRow from './QuotaDimensionRow.vue'
import type { QuotaThresholdType, QuotaResetMode } from '@/constants/account'
const { t } = useI18n()
@@ -9,22 +10,22 @@ const props = withDefaults(defineProps<{
totalLimit: number | null
dailyLimit: number | null
weeklyLimit: number | null
dailyResetMode: 'rolling' | 'fixed' | null
dailyResetMode: QuotaResetMode | null
dailyResetHour: number | null
weeklyResetMode: 'rolling' | 'fixed' | null
weeklyResetMode: QuotaResetMode | null
weeklyResetDay: number | null
weeklyResetHour: number | null
resetTimezone: string | null
quotaNotifyGlobalEnabled?: boolean
quotaNotifyDailyEnabled?: boolean | null
quotaNotifyDailyThreshold?: number | null
quotaNotifyDailyThresholdType?: string | null
quotaNotifyDailyThresholdType?: QuotaThresholdType | null
quotaNotifyWeeklyEnabled?: boolean | null
quotaNotifyWeeklyThreshold?: number | null
quotaNotifyWeeklyThresholdType?: string | null
quotaNotifyWeeklyThresholdType?: QuotaThresholdType | null
quotaNotifyTotalEnabled?: boolean | null
quotaNotifyTotalThreshold?: number | null
quotaNotifyTotalThresholdType?: string | null
quotaNotifyTotalThresholdType?: QuotaThresholdType | null
}>(), {
quotaNotifyGlobalEnabled: false,
quotaNotifyDailyEnabled: null,
@@ -42,21 +43,21 @@ const emit = defineEmits<{
'update:totalLimit': [value: number | null]
'update:dailyLimit': [value: number | null]
'update:weeklyLimit': [value: number | null]
'update:dailyResetMode': [value: 'rolling' | 'fixed' | null]
'update:dailyResetMode': [value: QuotaResetMode | null]
'update:dailyResetHour': [value: number | null]
'update:weeklyResetMode': [value: 'rolling' | 'fixed' | null]
'update:weeklyResetMode': [value: QuotaResetMode | null]
'update:weeklyResetDay': [value: number | null]
'update:weeklyResetHour': [value: number | null]
'update:resetTimezone': [value: string | null]
'update:quotaNotifyDailyEnabled': [value: boolean | null]
'update:quotaNotifyDailyThreshold': [value: number | null]
'update:quotaNotifyDailyThresholdType': [value: string | null]
'update:quotaNotifyDailyThresholdType': [value: QuotaThresholdType | null]
'update:quotaNotifyWeeklyEnabled': [value: boolean | null]
'update:quotaNotifyWeeklyThreshold': [value: number | null]
'update:quotaNotifyWeeklyThresholdType': [value: string | null]
'update:quotaNotifyWeeklyThresholdType': [value: QuotaThresholdType | null]
'update:quotaNotifyTotalEnabled': [value: boolean | null]
'update:quotaNotifyTotalThreshold': [value: number | null]
'update:quotaNotifyTotalThresholdType': [value: string | null]
'update:quotaNotifyTotalThresholdType': [value: QuotaThresholdType | null]
}>()
const enabled = computed(() =>
@@ -89,11 +90,6 @@ watch(localEnabled, (val) => {
}
})
// Whether any fixed mode is active (to show timezone selector)
const hasFixedMode = computed(() =>
props.dailyResetMode === 'fixed' || props.weeklyResetMode === 'fixed'
)
// Common timezone options
const timezoneOptions = [
'UTC', 'Asia/Shanghai', 'Asia/Tokyo', 'Asia/Seoul', 'Asia/Singapore', 'Asia/Kolkata',
@@ -102,18 +98,6 @@ const timezoneOptions = [
'America/Sao_Paulo', 'Australia/Sydney', 'Pacific/Auckland',
]
// Compute GMT offset label (e.g. "GMT+8", "GMT-5") for a given IANA timezone.
function getTimezoneOffsetLabel(tz: string): string {
try {
const dtf = new Intl.DateTimeFormat('en-US', { timeZone: tz, timeZoneName: 'shortOffset' })
const parts = dtf.formatToParts(new Date())
const tzPart = parts.find(p => p.type === 'timeZoneName')
return tzPart ? (tzPart.value === 'GMT' ? 'GMT+0' : tzPart.value) : ''
} catch {
return ''
}
}
// Hours for dropdown (0-23)
const hourOptions = Array.from({ length: 24 }, (_, i) => i)
@@ -197,6 +181,7 @@ const dailyFixedHint = computed(() =>
:hint-fixed="dailyFixedHint"
:hour-options="hourOptions"
:day-options="dayOptions"
:timezone-options="timezoneOptions"
@update:limit="emit('update:dailyLimit', $event)"
@update:notify-enabled="emit('update:quotaNotifyDailyEnabled', $event)"
@update:notify-threshold="emit('update:quotaNotifyDailyThreshold', $event)"
@@ -223,6 +208,7 @@ const dailyFixedHint = computed(() =>
:hint-fixed="weeklyFixedHint"
:hour-options="hourOptions"
:day-options="dayOptions"
:timezone-options="timezoneOptions"
@update:limit="emit('update:weeklyLimit', $event)"
@update:notify-enabled="emit('update:quotaNotifyWeeklyEnabled', $event)"
@update:notify-threshold="emit('update:quotaNotifyWeeklyThreshold', $event)"
@@ -233,14 +219,6 @@ const dailyFixedHint = computed(() =>
@update:reset-timezone="emit('update:resetTimezone', $event)"
/>
<!-- Timezone selector (shared by daily/weekly when fixed mode is active) -->
<div v-if="hasFixedMode">
<label class="input-label">{{ t('admin.accounts.quotaResetTimezone') }}</label>
<select :value="resetTimezone || 'UTC'" @change="emit('update:resetTimezone', ($event.target as HTMLSelectElement).value)" class="input text-sm">
<option v-for="tz in timezoneOptions" :key="tz" :value="tz">{{ tz }} ({{ getTimezoneOffsetLabel(tz) }})</option>
</select>
</div>
<!-- Total quota -->
<QuotaDimensionRow
dim="total"

View File

@@ -1,16 +1,16 @@
<script setup lang="ts">
import { QUOTA_THRESHOLD_TYPE_FIXED, QUOTA_THRESHOLD_TYPE_PERCENTAGE } from '@/constants/account'
import { QUOTA_THRESHOLD_TYPE_FIXED, QUOTA_THRESHOLD_TYPE_PERCENTAGE, type QuotaThresholdType } from '@/constants/account'
defineProps<{
enabled: boolean | null
threshold: number | null
thresholdType: string | null // "fixed" (default) or "percentage"
thresholdType: QuotaThresholdType | null
}>()
const emit = defineEmits<{
'update:enabled': [value: boolean | null]
'update:threshold': [value: number | null]
'update:thresholdType': [value: string | null]
'update:thresholdType': [value: QuotaThresholdType | null]
}>()
</script>
@@ -43,7 +43,7 @@ const emit = defineEmits<{
/>
<select
:value="thresholdType || QUOTA_THRESHOLD_TYPE_FIXED"
@change="emit('update:thresholdType', ($event.target as HTMLSelectElement).value)"
@change="emit('update:thresholdType', ($event.target as HTMLSelectElement).value as QuotaThresholdType)"
class="input py-1 text-xs w-[4.5rem] flex-shrink-0 text-center"
>
<option :value="QUOTA_THRESHOLD_TYPE_FIXED">$</option>

View File

@@ -313,10 +313,6 @@
<span class="text-gray-400">{{ t('usage.rate') }}</span>
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.rate_multiplier || 1) }}x</span>
</div>
<div class="flex items-center justify-between gap-6">
<span class="text-gray-400">{{ t('usage.accountMultiplier') }}</span>
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.account_rate_multiplier ?? 1) }}x</span>
</div>
<div class="flex items-center justify-between gap-6">
<span class="text-gray-400">{{ t('usage.original') }}</span>
<span class="font-medium text-white">${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}</span>
@@ -325,7 +321,12 @@
<span class="text-gray-400">{{ t('usage.userBilled') }}</span>
<span class="font-semibold text-green-400">${{ tooltipData?.actual_cost?.toFixed(6) || '0.000000' }}</span>
</div>
<!-- Account billing (separated from user billing) -->
<div class="flex items-center justify-between gap-6 border-t border-gray-700 pt-1.5">
<span class="text-gray-400">{{ t('usage.accountMultiplier') }}</span>
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.account_rate_multiplier ?? 1) }}x</span>
</div>
<div class="flex items-center justify-between gap-6">
<span class="text-gray-400">{{ t('usage.accountBilled') }}</span>
<span class="font-semibold text-green-400">
${{ accountBilled({
@@ -355,7 +356,8 @@ import { getBillingModeLabel, getBillingModeBadgeClass, BILLING_MODE_TOKEN, BILL
/** Compute the account-billed cost for display: (account_stats_cost ?? total_cost) * rate_multiplier */
function accountBilled(row: { total_cost?: number | null; account_stats_cost?: number | null; account_rate_multiplier?: number | null }): number {
const base = row.account_stats_cost != null ? row.account_stats_cost : (row.total_cost ?? 0)
return base * (row.account_rate_multiplier ?? 1)
const result = base * (row.account_rate_multiplier ?? 1)
return Number.isNaN(result) ? 0 : result
}
import DataTable from '@/components/common/DataTable.vue'

View File

@@ -1,6 +1,6 @@
import { reactive, ref } from 'vue'
import { adminAPI } from '@/api/admin'
import { QUOTA_THRESHOLD_TYPE_FIXED } from '@/constants/account'
import { QUOTA_THRESHOLD_TYPE_FIXED, type QuotaThresholdType } from '@/constants/account'
export const QUOTA_NOTIFY_DIMS = ['daily', 'weekly', 'total'] as const
export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number]
@@ -8,7 +8,7 @@ export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number]
interface DimState {
enabled: boolean | null
threshold: number | null
thresholdType: string | null
thresholdType: QuotaThresholdType | null
}
export function useQuotaNotifyState() {
@@ -34,7 +34,7 @@ export function useQuotaNotifyState() {
for (const d of QUOTA_NOTIFY_DIMS) {
state[d].enabled = (extra?.[`quota_notify_${d}_enabled`] as boolean) ?? null
state[d].threshold = (extra?.[`quota_notify_${d}_threshold`] as number) ?? null
state[d].thresholdType = (extra?.[`quota_notify_${d}_threshold_type`] as string) ?? null
state[d].thresholdType = (extra?.[`quota_notify_${d}_threshold_type`] as QuotaThresholdType) ?? null
}
}

View File

@@ -8,3 +8,8 @@ export type WebSearchMode = typeof WEB_SEARCH_MODE_DEFAULT | typeof WEB_SEARCH_M
export const QUOTA_THRESHOLD_TYPE_FIXED = 'fixed' as const
export const QUOTA_THRESHOLD_TYPE_PERCENTAGE = 'percentage' as const
export type QuotaThresholdType = typeof QUOTA_THRESHOLD_TYPE_FIXED | typeof QUOTA_THRESHOLD_TYPE_PERCENTAGE
/** Quota reset mode values */
export const QUOTA_RESET_MODE_ROLLING = 'rolling' as const
export const QUOTA_RESET_MODE_FIXED = 'fixed' as const
export type QuotaResetMode = typeof QUOTA_RESET_MODE_ROLLING | typeof QUOTA_RESET_MODE_FIXED

View File

@@ -166,8 +166,8 @@
class="channel-tab group"
:class="activeTab === section.platform ? 'channel-tab-active' : 'channel-tab-inactive'"
>
<PlatformIcon :platform="section.platform" size="xs" :class="getPlatformTextColor(section.platform)" />
<span :class="getPlatformTextColor(section.platform)">{{ t('admin.groups.platforms.' + section.platform, section.platform) }}</span>
<PlatformIcon :platform="section.platform" size="xs" :class="platformTextClass(section.platform)" />
<span :class="platformTextClass(section.platform)">{{ t('admin.groups.platforms.' + section.platform, section.platform) }}</span>
</button>
</div>
@@ -246,8 +246,8 @@
class="h-3.5 w-3.5 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
@change="togglePlatform(p)"
/>
<PlatformIcon :platform="p" size="xs" :class="getPlatformTextColor(p)" />
<span :class="getPlatformTextColor(p)">{{ t('admin.groups.platforms.' + p, p) }}</span>
<PlatformIcon :platform="p" size="xs" :class="platformTextClass(p)" />
<span :class="platformTextClass(p)">{{ t('admin.groups.platforms.' + p, p) }}</span>
</label>
</div>
</div>
@@ -310,9 +310,9 @@
class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
@change="toggleGroupInSection(sIdx, group.id)"
/>
<span :class="['font-medium', getPlatformTextColor(group.platform)]">{{ group.name }}</span>
<span :class="['font-medium', platformTextClass(group.platform)]">{{ group.name }}</span>
<span
:class="['rounded-full px-1 py-0 text-[10px]', getRateBadgeClass(group.platform)]"
:class="['rounded-full px-1 py-0 text-[10px]', platformBadgeLightClass(group.platform)]"
>{{ group.rate_multiplier }}x</span>
<span class="text-[10px] text-gray-400">{{ group.account_count || 0 }}</span>
<span
@@ -363,7 +363,7 @@
:value="srcModel"
type="text"
class="input flex-1 text-xs"
:class="getPlatformTextColor(section.platform)"
:class="platformTextClass(section.platform)"
:placeholder="t('admin.channels.form.mappingSource', 'Source model')"
@change="renameMappingKey(sIdx, srcModel, ($event.target as HTMLInputElement).value)"
/>
@@ -372,7 +372,7 @@
:value="section.model_mapping[srcModel]"
type="text"
class="input flex-1 text-xs"
:class="getPlatformTextColor(section.platform)"
:class="platformTextClass(section.platform)"
:placeholder="t('admin.channels.form.mappingTarget', 'Target model')"
@input="section.model_mapping[srcModel] = ($event.target as HTMLInputElement).value"
/>
@@ -464,7 +464,7 @@
: 'border-gray-200 hover:bg-gray-50 dark:border-dark-600 dark:hover:bg-dark-700'"
>
<input type="checkbox" :checked="rule.group_ids.includes(gid)" class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500" @change="rule.group_ids.includes(gid) ? rule.group_ids.splice(rule.group_ids.indexOf(gid), 1) : rule.group_ids.push(gid)" />
<span>{{ getGroupNameById(gid) }}</span>
<span :class="['font-medium', platformTextClass(section.platform)]">{{ getGroupNameById(gid) }}</span>
</label>
</div>
<p v-if="section.group_ids.length === 0" class="mt-1 text-xs text-gray-400">
@@ -481,7 +481,7 @@
:key="accountId"
class="inline-flex items-center gap-1 rounded-md border border-primary-300 bg-primary-50 px-2 py-0.5 text-xs dark:border-primary-700 dark:bg-primary-900/20"
>
<span>{{ getRuleAccountLabel(accountId) }}</span>
<span :class="['font-medium', platformTextClass(section.platform)]">{{ getRuleAccountLabel(accountId) }}</span>
<button type="button" @click="removeRuleAccount(rule, accountId)" class="text-gray-400 hover:text-red-500">
<Icon name="x" size="xs" />
</button>
@@ -595,7 +595,7 @@ import type { PricingFormEntry } from '@/components/admin/channel/types'
import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types'
import type { AdminGroup, GroupPlatform } from '@/types'
import type { Column } from '@/components/common/types'
import { platformTextClass } from '@/utils/platformColors'
import { platformTextClass, platformBadgeLightClass } from '@/utils/platformColors'
import AppLayout from '@/components/layout/AppLayout.vue'
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
import DataTable from '@/components/common/DataTable.vue'
@@ -720,26 +720,6 @@ let abortController: AbortController | null = null
// ── Platform config ──
const platformOrder: GroupPlatform[] = ['anthropic', 'openai', 'gemini', 'antigravity']
function getPlatformTextColor(platform: string): string {
switch (platform) {
case 'anthropic': return 'text-orange-600 dark:text-orange-400'
case 'openai': return 'text-emerald-600 dark:text-emerald-400'
case 'gemini': return 'text-blue-600 dark:text-blue-400'
case 'antigravity': return 'text-purple-600 dark:text-purple-400'
default: return 'text-gray-600 dark:text-gray-400'
}
}
function getRateBadgeClass(platform: string): string {
switch (platform) {
case 'anthropic': return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
case 'openai': return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
case 'gemini': return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400'
case 'antigravity': return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
default: return 'bg-gray-100 text-gray-700 dark:bg-gray-900/30 dark:text-gray-400'
}
}
// ── Helpers ──
function formatDate(value: string): string {
if (!value) return '-'