mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
fix: round-2 audit fixes — security, code quality, and UI improvements
Security (HIGH): - Normalize all Redis cache keys to lowercase (verifyCode, passwordReset) - Fix verify code TTL renewal on failed attempts: use remaining TTL via ExpiresAt field instead of resetting to full 15-minute window - Add 3 missing fields to diffSettings audit log (promo_code, invitation_code, custom_endpoints) Code quality (MEDIUM): - Extract filterVerifiedEmails shared helper (balance_notify_service.go) - Add Pricing array non-empty validation for channel pricing rules - Add platform token semantics comment in gateway_service.go - Complete validatePlanPatch test coverage (+10 test cases) - Replace string types with QuotaThresholdType/QuotaResetMode across frontend - Remove duplicate getPlatformTextColor/getRateBadgeClass in ChannelsView - Return EMAIL_NOT_FOUND error on RemoveNotifyEmail miss UI improvements: - Reorder cost tooltip: user billing above separator, account billing below - Add NaN guard to accountBilled function - Move timezone selector inline into reset-mode row (no longer standalone)
This commit is contained in:
@@ -357,6 +357,11 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
|
||||
return
|
||||
}
|
||||
if len(r.Pricing) == 0 {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
|
||||
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
|
||||
return
|
||||
}
|
||||
rule := accountStatsPricingRuleRequestToService(r)
|
||||
rule.SortOrder = i
|
||||
statsRules = append(statsRules, rule)
|
||||
@@ -420,6 +425,11 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
|
||||
return
|
||||
}
|
||||
if len(r.Pricing) == 0 {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
|
||||
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
|
||||
return
|
||||
}
|
||||
rule := accountStatsPricingRuleRequestToService(r)
|
||||
rule.SortOrder = i
|
||||
statsRules = append(statsRules, rule)
|
||||
|
||||
@@ -1138,6 +1138,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
|
||||
changed = append(changed, "registration_email_suffix_whitelist")
|
||||
}
|
||||
if before.PromoCodeEnabled != after.PromoCodeEnabled {
|
||||
changed = append(changed, "promo_code_enabled")
|
||||
}
|
||||
if before.InvitationCodeEnabled != after.InvitationCodeEnabled {
|
||||
changed = append(changed, "invitation_code_enabled")
|
||||
}
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
@@ -1348,6 +1354,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.CustomMenuItems != after.CustomMenuItems {
|
||||
changed = append(changed, "custom_menu_items")
|
||||
}
|
||||
if before.CustomEndpoints != after.CustomEndpoints {
|
||||
changed = append(changed, "custom_endpoints")
|
||||
}
|
||||
if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
|
||||
changed = append(changed, "enable_fingerprint_unification")
|
||||
}
|
||||
|
||||
@@ -20,8 +20,9 @@ const (
|
||||
)
|
||||
|
||||
// verifyCodeKey generates the Redis key for email verification code.
|
||||
// Email is lowercased for case-insensitive consistency.
|
||||
func verifyCodeKey(email string) string {
|
||||
return verifyCodeKeyPrefix + email
|
||||
return verifyCodeKeyPrefix + strings.ToLower(email)
|
||||
}
|
||||
|
||||
// notifyVerifyKey generates the Redis key for notify email verification code.
|
||||
@@ -33,12 +34,12 @@ func notifyVerifyKey(email string) string {
|
||||
|
||||
// passwordResetKey generates the Redis key for password reset token.
|
||||
func passwordResetKey(email string) string {
|
||||
return passwordResetKeyPrefix + email
|
||||
return passwordResetKeyPrefix + strings.ToLower(email)
|
||||
}
|
||||
|
||||
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
|
||||
func passwordResetSentAtKey(email string) string {
|
||||
return passwordResetSentAtKeyPrefix + email
|
||||
return passwordResetSentAtKeyPrefix + strings.ToLower(email)
|
||||
}
|
||||
|
||||
type emailCache struct {
|
||||
|
||||
@@ -283,6 +283,20 @@ func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context)
|
||||
return nil
|
||||
}
|
||||
|
||||
return filterVerifiedEmails(entries)
|
||||
}
|
||||
|
||||
// getSiteName reads site name from settings with fallback.
|
||||
func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
|
||||
name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
|
||||
if err != nil || name == "" {
|
||||
return defaultSiteName
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// filterVerifiedEmails returns deduplicated, non-disabled, verified emails.
|
||||
func filterVerifiedEmails(entries []NotifyEmailEntry) []string {
|
||||
var recipients []string
|
||||
seen := make(map[string]bool)
|
||||
for _, entry := range entries {
|
||||
@@ -303,38 +317,10 @@ func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context)
|
||||
return recipients
|
||||
}
|
||||
|
||||
// getSiteName reads site name from settings with fallback.
|
||||
func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
|
||||
name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
|
||||
if err != nil || name == "" {
|
||||
return defaultSiteName
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// collectBalanceNotifyRecipients returns verified, non-disabled email recipients.
|
||||
// Only emails with verified=true and disabled=false are included.
|
||||
func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string {
|
||||
var recipients []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, entry := range user.BalanceNotifyExtraEmails {
|
||||
if entry.Disabled || !entry.Verified {
|
||||
continue
|
||||
}
|
||||
email := strings.TrimSpace(entry.Email)
|
||||
if email == "" {
|
||||
continue
|
||||
}
|
||||
lower := strings.ToLower(email)
|
||||
if seen[lower] {
|
||||
continue
|
||||
}
|
||||
seen[lower] = true
|
||||
recipients = append(recipients, email)
|
||||
}
|
||||
|
||||
return recipients
|
||||
return filterVerifiedEmails(user.BalanceNotifyExtraEmails)
|
||||
}
|
||||
|
||||
// sendEmails sends an email to all recipients with shared timeout and error logging.
|
||||
|
||||
@@ -55,6 +55,7 @@ type VerificationCodeData struct {
|
||||
Code string
|
||||
Attempts int
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time // absolute expiry; used to preserve remaining TTL when updating attempts
|
||||
}
|
||||
|
||||
// PasswordResetTokenData represents password reset token data
|
||||
@@ -263,6 +264,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
|
||||
Code: code,
|
||||
Attempts: 0,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(verifyCodeTTL),
|
||||
}
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
return fmt.Errorf("save verify code: %w", err)
|
||||
@@ -295,7 +297,11 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
|
||||
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
|
||||
data.Attempts++
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
remaining := time.Until(data.ExpiresAt)
|
||||
if remaining <= 0 {
|
||||
return ErrInvalidVerifyCode
|
||||
}
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, remaining); err != nil {
|
||||
slog.Error("failed to update verification attempt count", "email", email, "error", err)
|
||||
}
|
||||
if data.Attempts >= maxVerifyCodeAttempts {
|
||||
|
||||
@@ -1194,12 +1194,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
// 注意:强制平台模式不走混合调度
|
||||
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
||||
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.hydrateSelectedAccount(ctx, account)
|
||||
}
|
||||
|
||||
// antigravity 分组、强制平台模式或无分组使用单平台选择
|
||||
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
|
||||
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.hydrateSelectedAccount(ctx, account)
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
||||
@@ -1275,11 +1283,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
localExcluded[account.ID] = struct{}{} // 排除此账号
|
||||
continue // 重新选择
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
|
||||
}
|
||||
|
||||
// 对于等待计划的情况,也需要先检查会话限制
|
||||
@@ -1291,26 +1295,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1433,36 +1431,39 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
||||
// 粘性账号在路由列表中,优先使用
|
||||
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
||||
if s.isAccountSchedulableForSelection(stickyAccount) &&
|
||||
var stickyCacheMissReason string
|
||||
|
||||
gatePass := s.isAccountSchedulableForSelection(stickyAccount) &&
|
||||
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
|
||||
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
|
||||
s.isAccountSchedulableForQuota(stickyAccount) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true)
|
||||
|
||||
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
|
||||
rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true)
|
||||
|
||||
if rpmPass { // 粘性会话窗口费用+RPM 检查
|
||||
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位
|
||||
stickyCacheMissReason = "session_limit"
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil)
|
||||
}
|
||||
}
|
||||
|
||||
if stickyCacheMissReason == "" {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
stickyCacheMissReason = "session_limit"
|
||||
// 会话限制已满,继续到负载感知选择
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
@@ -1475,11 +1476,31 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
} else {
|
||||
stickyCacheMissReason = "wait_queue_full"
|
||||
}
|
||||
}
|
||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||
} else if !gatePass {
|
||||
stickyCacheMissReason = "gate_check"
|
||||
} else {
|
||||
stickyCacheMissReason = "rpm_red"
|
||||
}
|
||||
|
||||
// 记录粘性缓存未命中的结构化日志
|
||||
if stickyCacheMissReason != "" {
|
||||
baseRPM := stickyAccount.GetBaseRPM()
|
||||
var currentRPM int
|
||||
if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok {
|
||||
currentRPM = count
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d",
|
||||
stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM)
|
||||
}
|
||||
} else {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0",
|
||||
stickyAccountID, shortSessionHash(sessionHash))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1544,11 +1565,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1561,15 +1578,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{
|
||||
AccountID: item.account.ID,
|
||||
MaxConcurrency: item.account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
|
||||
}
|
||||
@@ -1603,11 +1617,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
if s.cache != nil {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
}
|
||||
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1617,15 +1630,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1684,7 +1694,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
|
||||
if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil {
|
||||
return nil, legacyErr
|
||||
} else if ok {
|
||||
return result, nil
|
||||
}
|
||||
} else {
|
||||
@@ -1723,11 +1735,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: selected.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1750,20 +1758,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
|
||||
continue // 会话限制已满,尝试下一个账号
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{
|
||||
AccountID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||
|
||||
@@ -1778,15 +1783,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, true
|
||||
selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return selection, true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
@@ -2401,6 +2406,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
|
||||
return s.accountRepo.GetByID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) {
|
||||
if account == nil || s.schedulerSnapshot == nil {
|
||||
return account, nil
|
||||
}
|
||||
hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hydrated == nil {
|
||||
return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID)
|
||||
}
|
||||
return hydrated, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) {
|
||||
hydrated, err := s.hydrateSelectedAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: hydrated,
|
||||
Acquired: acquired,
|
||||
ReleaseFunc: release,
|
||||
WaitPlan: waitPlan,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// filterByMinPriority 过滤出优先级最小的账号集合
|
||||
func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad {
|
||||
if len(accounts) == 0 {
|
||||
@@ -2676,6 +2708,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
preferOAuth := platform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
||||
|
||||
// require_privacy_set: 获取分组信息
|
||||
var schedGroup *Group
|
||||
if groupID != nil && s.groupRepo != nil {
|
||||
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
|
||||
}
|
||||
|
||||
var accounts []Account
|
||||
accountsLoaded := false
|
||||
|
||||
@@ -2747,6 +2785,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -2852,6 +2896,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -2918,6 +2968,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
preferOAuth := nativePlatform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
|
||||
|
||||
// require_privacy_set: 获取分组信息
|
||||
var schedGroup *Group
|
||||
if groupID != nil && s.groupRepo != nil {
|
||||
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
|
||||
}
|
||||
|
||||
var accounts []Account
|
||||
accountsLoaded := false
|
||||
|
||||
@@ -2985,6 +3041,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
@@ -3078,6 +3140,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
|
||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
@@ -3090,6 +3153,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
@@ -3257,8 +3326,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
||||
return selectionFailureDiagnosis{Category: "excluded"}
|
||||
}
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
detail := "generic_unschedulable"
|
||||
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
|
||||
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
|
||||
}
|
||||
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
|
||||
return selectionFailureDiagnosis{
|
||||
@@ -3282,7 +3350,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
||||
return selectionFailureDiagnosis{Category: "eligible"}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取账号凭证
|
||||
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
|
||||
if acc == nil {
|
||||
return true
|
||||
@@ -3653,6 +3720,86 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||
return result
|
||||
}
|
||||
|
||||
// rewriteSystemForNonClaudeCode 将非 Claude Code 客户端的 system prompt 迁移至 messages,
|
||||
// system 字段仅保留 Claude Code 标识提示词。
|
||||
// Anthropic 基于 system 参数内容检测第三方应用,仅前置追加 Claude Code 提示词
|
||||
// 无法通过检测,因为后续内容仍为非 Claude Code 格式。
|
||||
// 策略:将原始 system prompt 提取并注入为 user/assistant 消息对,system 仅保留 Claude Code 标识。
|
||||
func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
|
||||
system = normalizeSystemParam(system)
|
||||
|
||||
// 1. 提取原始 system prompt 文本
|
||||
var originalSystemText string
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
originalSystemText = strings.TrimSpace(v)
|
||||
case []any:
|
||||
var parts []string
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
originalSystemText = strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致)
|
||||
// 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。
|
||||
// 使用 string 格式会被 Anthropic 检测为第三方应用。
|
||||
claudeCodeSystemBlock := []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": claudeCodeSystemPrompt,
|
||||
"cache_control": map[string]string{"type": "ephemeral"},
|
||||
},
|
||||
}
|
||||
out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock)
|
||||
if !ok {
|
||||
logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt")
|
||||
return body
|
||||
}
|
||||
|
||||
// 3. 将原始 system prompt 作为 user/assistant 消息对注入到 messages 开头
|
||||
// 模型仍通过 messages 接收完整指令,保留客户端功能
|
||||
ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt)
|
||||
if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) {
|
||||
instrMsg, err1 := json.Marshal(map[string]any{
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{"type": "text", "text": "[System Instructions]\n" + originalSystemText},
|
||||
},
|
||||
})
|
||||
ackMsg, err2 := json.Marshal(map[string]any{
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{
|
||||
{"type": "text", "text": "Understood. I will follow these instructions."},
|
||||
},
|
||||
})
|
||||
if err1 != nil || err2 != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Warning: failed to marshal system-to-messages injection")
|
||||
return out
|
||||
}
|
||||
|
||||
// 重建 messages 数组:[instruction, ack, ...originalMessages]
|
||||
items := [][]byte{instrMsg, ackMsg}
|
||||
messagesResult := gjson.GetBytes(out, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messagesResult.ForEach(func(_, msg gjson.Result) bool {
|
||||
items = append(items, []byte(msg.Raw))
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if next, setOk := setJSONRawBytes(out, "messages", buildJSONArrayRaw(items)); setOk {
|
||||
out = next
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
type cacheControlPath struct {
|
||||
path string
|
||||
log string
|
||||
@@ -3819,7 +3966,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
|
||||
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
|
||||
if account.Platform == PlatformAnthropic && c != nil {
|
||||
policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account)
|
||||
policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account, parsed.Model)
|
||||
if policy.blockErr != nil {
|
||||
return nil, policy.blockErr
|
||||
}
|
||||
@@ -3849,19 +3996,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
|
||||
if shouldMimicClaudeCode {
|
||||
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
|
||||
// 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages
|
||||
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
|
||||
systemRewritten := false
|
||||
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
|
||||
!systemIncludesClaudeCodePrompt(parsed.System) {
|
||||
body = injectClaudeCodePrompt(body, parsed.System)
|
||||
body = rewriteSystemForNonClaudeCode(body, parsed.System)
|
||||
systemRewritten = true
|
||||
}
|
||||
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
||||
// system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为);
|
||||
// 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。
|
||||
// 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
|
||||
if s.identityService != nil {
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if err == nil && fp != nil {
|
||||
// metadata 透传开启时跳过 metadata 注入
|
||||
_, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx)
|
||||
_, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx)
|
||||
if !mimicMPT {
|
||||
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
|
||||
normalizeOpts.injectMetadata = true
|
||||
@@ -5407,9 +5559,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
|
||||
// OAuth账号:应用统一指纹和metadata重写(受设置开关控制)
|
||||
var fingerprint *Fingerprint
|
||||
enableFP, enableMPT := true, false
|
||||
enableFP, enableMPT, enableCCH := true, false, false
|
||||
if s.settingService != nil {
|
||||
enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
|
||||
enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
|
||||
}
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
||||
@@ -5436,6 +5588,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
}
|
||||
|
||||
// 同步 billing header cc_version 与实际发送的 User-Agent 版本
|
||||
if fingerprint != nil {
|
||||
body = syncBillingHeaderVersion(body, fingerprint.UserAgent)
|
||||
}
|
||||
// CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后)
|
||||
if enableCCH {
|
||||
body = signBillingHeaderCCH(body)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -5476,9 +5637,8 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
|
||||
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account)
|
||||
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID)
|
||||
effectiveDropSet := mergeDropSets(policyFilterSet)
|
||||
effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode)
|
||||
|
||||
// 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta)
|
||||
if tokenType == "oauth" {
|
||||
@@ -5489,11 +5649,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
applyClaudeCodeMimicHeaders(req, reqStream)
|
||||
|
||||
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
|
||||
// Match real Claude CLI traffic (per mitmproxy reports):
|
||||
// messages requests typically use only oauth + interleaved-thinking.
|
||||
// Also drop claude-code beta if a downstream client added it.
|
||||
// Claude Code OAuth credentials are scoped to Claude Code.
|
||||
// Non-haiku models MUST include claude-code beta for Anthropic to recognize
|
||||
// this as a legitimate Claude Code request; without it, the request is
|
||||
// rejected as third-party ("out of extra usage").
|
||||
// Haiku models are exempt from third-party detection and don't need it.
|
||||
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
|
||||
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet))
|
||||
if !strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking}
|
||||
}
|
||||
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
|
||||
} else {
|
||||
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
|
||||
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
|
||||
@@ -5716,7 +5881,7 @@ type betaPolicyResult struct {
|
||||
}
|
||||
|
||||
// evaluateBetaPolicy loads settings once and evaluates all rules against the given request.
|
||||
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult {
|
||||
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account, model string) betaPolicyResult {
|
||||
if s.settingService == nil {
|
||||
return betaPolicyResult{}
|
||||
}
|
||||
@@ -5731,10 +5896,11 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
||||
continue
|
||||
}
|
||||
switch rule.Action {
|
||||
effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
|
||||
switch effectiveAction {
|
||||
case BetaPolicyActionBlock:
|
||||
if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) {
|
||||
msg := rule.ErrorMessage
|
||||
msg := effectiveErrMsg
|
||||
if msg == "" {
|
||||
msg = "beta feature " + rule.BetaToken + " is not allowed"
|
||||
}
|
||||
@@ -5776,7 +5942,7 @@ const betaPolicyFilterSetKey = "betaPolicyFilterSet"
|
||||
// In the /v1/messages path, Forward() evaluates the policy first and caches the result;
|
||||
// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this
|
||||
// evaluates on demand (one DB call).
|
||||
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} {
|
||||
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account, model string) map[string]struct{} {
|
||||
if c != nil {
|
||||
if v, ok := c.Get(betaPolicyFilterSetKey); ok {
|
||||
if fs, ok := v.(map[string]struct{}); ok {
|
||||
@@ -5784,7 +5950,7 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont
|
||||
}
|
||||
}
|
||||
}
|
||||
return s.evaluateBetaPolicy(ctx, "", account).filterSet
|
||||
return s.evaluateBetaPolicy(ctx, "", account, model).filterSet
|
||||
}
|
||||
|
||||
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
|
||||
@@ -5803,6 +5969,33 @@ func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// matchModelWhitelist checks if a model matches any pattern in the whitelist.
|
||||
// Reuses matchModelPattern from group.go which supports exact and wildcard prefix matching.
|
||||
func matchModelWhitelist(model string, whitelist []string) bool {
|
||||
for _, pattern := range whitelist {
|
||||
if matchModelPattern(pattern, model) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveRuleAction determines the effective action and error message for a rule given the request model.
|
||||
// When ModelWhitelist is empty, the rule's primary Action/ErrorMessage applies unconditionally.
|
||||
// When non-empty, Action applies to matching models; FallbackAction/FallbackErrorMessage applies to others.
|
||||
func resolveRuleAction(rule BetaPolicyRule, model string) (action, errorMessage string) {
|
||||
if len(rule.ModelWhitelist) == 0 {
|
||||
return rule.Action, rule.ErrorMessage
|
||||
}
|
||||
if matchModelWhitelist(model, rule.ModelWhitelist) {
|
||||
return rule.Action, rule.ErrorMessage
|
||||
}
|
||||
if rule.FallbackAction != "" {
|
||||
return rule.FallbackAction, rule.FallbackErrorMessage
|
||||
}
|
||||
return BetaPolicyActionPass, "" // default fallback: pass (fail-open)
|
||||
}
|
||||
|
||||
// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
|
||||
func droppedBetaSet(extra ...string) map[string]struct{} {
|
||||
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra))
|
||||
@@ -5849,7 +6042,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
|
||||
modelID string,
|
||||
) ([]string, error) {
|
||||
// 1. 对原始 header 中的 beta token 做 block 检查(快速失败)
|
||||
policy := s.evaluateBetaPolicy(ctx, betaHeader, account)
|
||||
policy := s.evaluateBetaPolicy(ctx, betaHeader, account, modelID)
|
||||
if policy.blockErr != nil {
|
||||
return nil, policy.blockErr
|
||||
}
|
||||
@@ -5861,7 +6054,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
|
||||
// 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token,
|
||||
// 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 →
|
||||
// 如果不做此检查,block 规则会被绕过。
|
||||
if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil {
|
||||
if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account, modelID); blockErr != nil {
|
||||
return nil, blockErr
|
||||
}
|
||||
|
||||
@@ -5870,7 +6063,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
|
||||
|
||||
// checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。
|
||||
// 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。
|
||||
func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError {
|
||||
func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account, model string) *BetaBlockedError {
|
||||
if s.settingService == nil || len(tokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -5882,14 +6075,15 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke
|
||||
isBedrock := account.IsBedrock()
|
||||
tokenSet := buildBetaTokenSet(tokens)
|
||||
for _, rule := range settings.Rules {
|
||||
if rule.Action != BetaPolicyActionBlock {
|
||||
effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
|
||||
if effectiveAction != BetaPolicyActionBlock {
|
||||
continue
|
||||
}
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
||||
continue
|
||||
}
|
||||
if _, present := tokenSet[rule.BetaToken]; present {
|
||||
msg := rule.ErrorMessage
|
||||
msg := effectiveErrMsg
|
||||
if msg == "" {
|
||||
msg = "beta feature " + rule.BetaToken + " is not allowed"
|
||||
}
|
||||
@@ -7146,49 +7340,41 @@ func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
|
||||
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
|
||||
}
|
||||
|
||||
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
|
||||
// - 订阅/余额扣费
|
||||
// - API Key 配额更新
|
||||
// - API Key 限速用量更新
|
||||
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
|
||||
// postUsageBilling is the legacy fallback billing path used when the unified
|
||||
// billing repo is unavailable (nil). Production uses applyUsageBilling → repo.Apply
|
||||
// for atomic billing. This path only runs in tests or degraded mode.
|
||||
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
|
||||
billingCtx, cancel := detachedBillingContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
cost := p.Cost
|
||||
|
||||
// 1. 订阅 / 余额扣费
|
||||
if p.IsSubscriptionBill {
|
||||
if cost.TotalCost > 0 {
|
||||
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
|
||||
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. API Key 配额
|
||||
if p.shouldDeductAPIKeyQuota() {
|
||||
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. API Key 限速用量
|
||||
if p.shouldUpdateRateLimits() {
|
||||
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if p.shouldUpdateAccountQuota() {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
|
||||
@@ -7196,7 +7382,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
}
|
||||
}
|
||||
|
||||
finalizePostUsageBilling(p, deps)
|
||||
// NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing
|
||||
// cache updates. The legacy path does DB writes directly; the finalize path
|
||||
// does cache queue + notifications. Notifications are dispatched separately
|
||||
// by the caller after recording the usage log.
|
||||
}
|
||||
|
||||
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
|
||||
@@ -7250,9 +7439,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
|
||||
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
|
||||
cmd.CacheReadTokens = usageLog.CacheReadTokens
|
||||
cmd.ImageCount = usageLog.ImageCount
|
||||
if usageLog.MediaType != nil {
|
||||
cmd.MediaType = *usageLog.MediaType
|
||||
}
|
||||
if usageLog.ServiceTier != nil {
|
||||
cmd.ServiceTier = *usageLog.ServiceTier
|
||||
}
|
||||
@@ -7315,11 +7501,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog
|
||||
}
|
||||
}
|
||||
|
||||
finalizePostUsageBilling(p, deps)
|
||||
finalizePostUsageBilling(p, deps, result)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
|
||||
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
|
||||
if p == nil || p.Cost == nil || deps == nil {
|
||||
return
|
||||
}
|
||||
@@ -7338,22 +7524,82 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
|
||||
|
||||
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
||||
|
||||
// Balance low notification — use real-time balance from billing cache (not stale snapshot)
|
||||
if !p.IsSubscriptionBill && p.Cost.ActualCost > 0 && p.User != nil && deps.balanceNotifyService != nil {
|
||||
oldBalance := p.User.Balance // fallback to snapshot
|
||||
if deps.billingCacheService != nil {
|
||||
if realBalance, err := deps.billingCacheService.GetUserBalance(context.Background(), p.User.ID); err == nil {
|
||||
oldBalance = realBalance + p.Cost.ActualCost // DB already deducted, reconstruct pre-deduction balance
|
||||
// Notification checks run async — all parameters are already captured,
|
||||
// no dependency on the request context or upstream connection.
|
||||
go notifyBalanceLow(p, deps, result)
|
||||
go notifyAccountQuota(p, deps, result)
|
||||
}
|
||||
|
||||
// notifyBalanceLow sends balance low notification after deduction.
|
||||
// When result.NewBalance is available (from DB transaction RETURNING), it is used directly
|
||||
// to reconstruct oldBalance, avoiding stale Redis reads and concurrent-deduction races.
|
||||
func notifyBalanceLow(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("panic in notifyBalanceLow", "recover", r)
|
||||
}
|
||||
}
|
||||
deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost)
|
||||
}()
|
||||
if p.IsSubscriptionBill || p.Cost.ActualCost <= 0 || p.User == nil || deps.balanceNotifyService == nil {
|
||||
slog.Debug("notifyBalanceLow: skipped",
|
||||
"is_subscription", p.IsSubscriptionBill,
|
||||
"actual_cost", p.Cost.ActualCost,
|
||||
"user_nil", p.User == nil,
|
||||
"service_nil", deps.balanceNotifyService == nil,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Account quota notification (use same cost formula as postUsageBilling)
|
||||
if p.Cost.TotalCost > 0 && p.Account != nil && p.Account.IsAPIKeyOrBedrock() && deps.balanceNotifyService != nil {
|
||||
accountCost := p.Cost.TotalCost * p.AccountRateMultiplier
|
||||
deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost)
|
||||
oldBalance := resolveOldBalance(p, result)
|
||||
slog.Debug("notifyBalanceLow: calling CheckBalanceAfterDeduction",
|
||||
"user_id", p.User.ID,
|
||||
"old_balance", oldBalance,
|
||||
"cost", p.Cost.ActualCost,
|
||||
"notify_enabled", p.User.BalanceNotifyEnabled,
|
||||
"threshold", p.User.BalanceNotifyThreshold,
|
||||
"result_has_new_balance", result != nil && result.NewBalance != nil,
|
||||
)
|
||||
deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost)
|
||||
}
|
||||
|
||||
// resolveOldBalance returns the pre-deduction balance.
|
||||
// Prefers the DB transaction result (newBalance + cost) over snapshot.
|
||||
func resolveOldBalance(p *postUsageBillingParams, result *UsageBillingApplyResult) float64 {
|
||||
if result != nil && result.NewBalance != nil {
|
||||
return *result.NewBalance + p.Cost.ActualCost
|
||||
}
|
||||
// Legacy fallback: snapshot balance from request context
|
||||
return p.User.Balance
|
||||
}
|
||||
|
||||
// notifyAccountQuota sends account quota threshold notification after increment.
|
||||
// When result.QuotaState is available (from DB transaction RETURNING), it is passed directly
|
||||
// to avoid a separate DB read that may see stale or concurrently-modified data.
|
||||
func notifyAccountQuota(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("panic in notifyAccountQuota", "recover", r)
|
||||
}
|
||||
}()
|
||||
if p.Cost.TotalCost <= 0 || p.Account == nil || !p.Account.IsAPIKeyOrBedrock() || deps.balanceNotifyService == nil {
|
||||
slog.Debug("notifyAccountQuota: skipped",
|
||||
"total_cost", p.Cost.TotalCost,
|
||||
"account_nil", p.Account == nil,
|
||||
"is_apikey_or_bedrock", p.Account != nil && p.Account.IsAPIKeyOrBedrock(),
|
||||
"service_nil", deps.balanceNotifyService == nil,
|
||||
)
|
||||
return
|
||||
}
|
||||
accountCost := p.Cost.TotalCost * p.AccountRateMultiplier
|
||||
var quotaState *AccountQuotaState
|
||||
if result != nil {
|
||||
quotaState = result.QuotaState
|
||||
}
|
||||
slog.Debug("notifyAccountQuota: calling CheckAccountQuotaAfterIncrement",
|
||||
"account_id", p.Account.ID,
|
||||
"account_cost", accountCost,
|
||||
"has_quota_state", quotaState != nil,
|
||||
)
|
||||
deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost, quotaState)
|
||||
}
|
||||
|
||||
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
@@ -7422,11 +7668,11 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
|
||||
|
||||
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
|
||||
type recordUsageOpts struct {
|
||||
// ParsedRequest(可选,仅 Claude 路径传入)
|
||||
// Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入)
|
||||
ParsedRequest *ParsedRequest
|
||||
|
||||
// EnableClaudePath 启用 Claude 路径特有逻辑:
|
||||
// - MediaType 字段写入使用日志
|
||||
// - Claude Max 缓存计费策略
|
||||
EnableClaudePath bool
|
||||
|
||||
// 长上下文计费(仅 Gemini 路径需要)
|
||||
@@ -7451,7 +7697,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
APIKeyService: input.APIKeyService,
|
||||
ChannelUsageFields: input.ChannelUsageFields,
|
||||
}, &recordUsageOpts{
|
||||
ParsedRequest: input.ParsedRequest,
|
||||
EnableClaudePath: true,
|
||||
})
|
||||
}
|
||||
@@ -7517,6 +7762,7 @@ type recordUsageCoreInput struct {
|
||||
|
||||
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
|
||||
// opts 中的字段控制两者之间的差异行为:
|
||||
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
|
||||
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
|
||||
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
|
||||
result := input.Result
|
||||
@@ -7583,13 +7829,10 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
|
||||
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
|
||||
if apiKey.GroupID != nil {
|
||||
upstreamModel := result.UpstreamModel
|
||||
if upstreamModel == "" {
|
||||
upstreamModel = result.Model
|
||||
}
|
||||
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||
ctx, s.channelService, s.billingService,
|
||||
account.ID, *apiKey.GroupID, upstreamModel,
|
||||
applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService,
|
||||
account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model,
|
||||
// Anthropic's input_tokens excludes cache_read and cache_creation (billed separately);
|
||||
// OpenAI gateway uses actualInputTokens which also excludes cache_read for the same reason.
|
||||
UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
@@ -7597,7 +7840,6 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
},
|
||||
1, // requestCount
|
||||
cost.TotalCost,
|
||||
)
|
||||
}
|
||||
@@ -7796,13 +8038,12 @@ func (s *GatewayService) buildRecordUsageLog(
|
||||
RateMultiplier: multiplier,
|
||||
AccountRateMultiplier: &accountRateMultiplier,
|
||||
BillingType: billingType,
|
||||
BillingMode: resolveBillingMode(opts, result, cost),
|
||||
BillingMode: resolveBillingMode(result, cost),
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
|
||||
MediaType: resolveMediaType(opts, result),
|
||||
CacheTTLOverridden: cacheTTLOverridden,
|
||||
ChannelID: optionalInt64Ptr(input.ChannelID),
|
||||
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
|
||||
@@ -7826,7 +8067,7 @@ func (s *GatewayService) buildRecordUsageLog(
|
||||
}
|
||||
|
||||
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
|
||||
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
|
||||
func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
|
||||
var mode string
|
||||
switch {
|
||||
case cost != nil && cost.BillingMode != "":
|
||||
@@ -7839,10 +8080,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
|
||||
return &mode
|
||||
}
|
||||
|
||||
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
|
||||
if subscription != nil {
|
||||
return &subscription.ID
|
||||
@@ -8349,9 +8586,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
|
||||
// OAuth 账号:应用统一指纹和重写 userID(受设置开关控制)
|
||||
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
|
||||
ctEnableFP, ctEnableMPT := true, false
|
||||
ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false
|
||||
if s.settingService != nil {
|
||||
ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
|
||||
ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
|
||||
}
|
||||
var ctFingerprint *Fingerprint
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
@@ -8369,6 +8606,14 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
}
|
||||
|
||||
// 同步 billing header cc_version 与实际发送的 User-Agent 版本
|
||||
if ctFingerprint != nil && ctEnableFP {
|
||||
body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent)
|
||||
}
|
||||
if ctEnableCCH {
|
||||
body = signBillingHeaderCCH(body)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -8409,7 +8654,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
|
||||
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
|
||||
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account))
|
||||
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID))
|
||||
|
||||
// OAuth 账号:处理 anthropic-beta header
|
||||
if tokenType == "oauth" {
|
||||
|
||||
@@ -128,3 +128,66 @@ func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) {
|
||||
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// --- validatePlanPatch: other fields ---
|
||||
|
||||
func ptrStr(s string) *string { return &s }
|
||||
func ptrInt(i int) *int { return &i }
|
||||
func ptrInt64(i int64) *int64 { return &i }
|
||||
func ptrFloat(f float64) *float64 { return &f }
|
||||
|
||||
func TestValidatePlanPatch_EmptyName(t *testing.T) {
|
||||
err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("")})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "plan name")
|
||||
}
|
||||
|
||||
func TestValidatePlanPatch_ValidName(t *testing.T) {
|
||||
err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("Basic")})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidatePlanPatch_ZeroGroupID(t *testing.T) {
|
||||
err := validatePlanPatch(UpdatePlanRequest{GroupID: ptrInt64(0)})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "group")
|
||||
}
|
||||
|
||||
func TestValidatePlanPatch_NegativePrice(t *testing.T) {
|
||||
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(-1)})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "price")
|
||||
}
|
||||
|
||||
func TestValidatePlanPatch_ZeroPrice(t *testing.T) {
|
||||
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(0)})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "price")
|
||||
}
|
||||
|
||||
func TestValidatePlanPatch_ValidPrice(t *testing.T) {
|
||||
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(9.99)})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidatePlanPatch_ZeroValidityDays(t *testing.T) {
|
||||
err := validatePlanPatch(UpdatePlanRequest{ValidityDays: ptrInt(0)})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "validity days")
|
||||
}
|
||||
|
||||
func TestValidatePlanPatch_EmptyValidityUnit(t *testing.T) {
|
||||
err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("")})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "validity unit")
|
||||
}
|
||||
|
||||
func TestValidatePlanPatch_ValidValidityUnit(t *testing.T) {
|
||||
err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("days")})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidatePlanPatch_AllNil(t *testing.T) {
|
||||
err := validatePlanPatch(UpdatePlanRequest{})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -330,6 +330,7 @@ func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code str
|
||||
Code: code,
|
||||
Attempts: 0,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(verifyCodeTTL),
|
||||
}
|
||||
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
return fmt.Errorf("save verify code: %w", err)
|
||||
@@ -370,7 +371,11 @@ func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string)
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
|
||||
data.Attempts++
|
||||
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
remaining := time.Until(data.ExpiresAt)
|
||||
if remaining <= 0 {
|
||||
return ErrInvalidVerifyCode
|
||||
}
|
||||
if err := cache.SetNotifyVerifyCode(ctx, email, data, remaining); err != nil {
|
||||
slog.Error("failed to update notify verify code attempts", "email", email, "error", err)
|
||||
}
|
||||
if data.Attempts >= maxVerifyCodeAttempts {
|
||||
@@ -418,11 +423,17 @@ func (s *UserService) RemoveNotifyEmail(ctx context.Context, userID int64, email
|
||||
}
|
||||
|
||||
filtered := make([]NotifyEmailEntry, 0, len(user.BalanceNotifyExtraEmails))
|
||||
found := false
|
||||
for _, e := range user.BalanceNotifyExtraEmails {
|
||||
if !strings.EqualFold(e.Email, email) {
|
||||
if strings.EqualFold(e.Email, email) {
|
||||
found = true
|
||||
} else {
|
||||
filtered = append(filtered, e)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return infraerrors.BadRequest("EMAIL_NOT_FOUND", "notification email not found")
|
||||
}
|
||||
user.BalanceNotifyExtraEmails = filtered
|
||||
return s.userRepo.Update(ctx, user)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
<script setup lang="ts">
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import QuotaNotifyToggle from './QuotaNotifyToggle.vue'
|
||||
import type { QuotaThresholdType, QuotaResetMode } from '@/constants/account'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
@@ -11,9 +12,9 @@ const props = defineProps<{
|
||||
quotaNotifyGlobalEnabled: boolean
|
||||
notifyEnabled: boolean | null
|
||||
notifyThreshold: number | null
|
||||
notifyThresholdType: string | null
|
||||
notifyThresholdType: QuotaThresholdType | null
|
||||
// Reset mode (only for daily/weekly, null for total)
|
||||
resetMode: 'rolling' | 'fixed' | null
|
||||
resetMode: QuotaResetMode | null
|
||||
resetHour: number | null
|
||||
resetDay: number | null // weekly only
|
||||
resetTimezone: string | null
|
||||
@@ -22,14 +23,15 @@ const props = defineProps<{
|
||||
// Shared options passed from parent
|
||||
hourOptions: number[]
|
||||
dayOptions: { value: number; key: string }[]
|
||||
timezoneOptions?: string[]
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:limit': [value: number | null]
|
||||
'update:notifyEnabled': [value: boolean | null]
|
||||
'update:notifyThreshold': [value: number | null]
|
||||
'update:notifyThresholdType': [value: string | null]
|
||||
'update:resetMode': [value: 'rolling' | 'fixed' | null]
|
||||
'update:notifyThresholdType': [value: QuotaThresholdType | null]
|
||||
'update:resetMode': [value: QuotaResetMode | null]
|
||||
'update:resetHour': [value: number | null]
|
||||
'update:resetDay': [value: number | null]
|
||||
'update:resetTimezone': [value: string | null]
|
||||
@@ -43,7 +45,7 @@ const onLimitInput = (e: Event) => {
|
||||
}
|
||||
|
||||
const onModeChange = (e: Event) => {
|
||||
const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed'
|
||||
const val = (e.target as HTMLSelectElement).value as QuotaResetMode
|
||||
emit('update:resetMode', val)
|
||||
if (val === 'fixed') {
|
||||
if (props.resetHour == null) emit('update:resetHour', 0)
|
||||
@@ -51,6 +53,17 @@ const onModeChange = (e: Event) => {
|
||||
if (!props.resetTimezone) emit('update:resetTimezone', 'UTC')
|
||||
}
|
||||
}
|
||||
|
||||
function getTimezoneOffsetLabel(tz: string): string {
|
||||
try {
|
||||
const dtf = new Intl.DateTimeFormat('en-US', { timeZone: tz, timeZoneName: 'shortOffset' })
|
||||
const parts = dtf.formatToParts(new Date())
|
||||
const tzPart = parts.find(p => p.type === 'timeZoneName')
|
||||
return tzPart ? (tzPart.value === 'GMT' ? 'GMT+0' : tzPart.value) : ''
|
||||
} catch {
|
||||
return ''
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -95,6 +108,11 @@ const onModeChange = (e: Event) => {
|
||||
<select :value="resetHour ?? 0" @change="emit('update:resetHour', Number(($event.target as HTMLSelectElement).value))" class="input py-1 text-xs w-24">
|
||||
<option v-for="h in hourOptions" :key="h" :value="h">{{ String(h).padStart(2, '0') }}:00</option>
|
||||
</select>
|
||||
<template v-if="timezoneOptions && timezoneOptions.length > 0">
|
||||
<select :value="resetTimezone || 'UTC'" @change="emit('update:resetTimezone', ($event.target as HTMLSelectElement).value)" class="input py-1 text-xs w-auto">
|
||||
<option v-for="tz in timezoneOptions" :key="tz" :value="tz">{{ tz }} ({{ getTimezoneOffsetLabel(tz) }})</option>
|
||||
</select>
|
||||
</template>
|
||||
</template>
|
||||
<span class="text-[11px] text-gray-500 dark:text-gray-400">
|
||||
<template v-if="resetMode === 'fixed'">{{ hintFixed }}</template>
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import { ref, watch, computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import QuotaDimensionRow from './QuotaDimensionRow.vue'
|
||||
import type { QuotaThresholdType, QuotaResetMode } from '@/constants/account'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
@@ -9,22 +10,22 @@ const props = withDefaults(defineProps<{
|
||||
totalLimit: number | null
|
||||
dailyLimit: number | null
|
||||
weeklyLimit: number | null
|
||||
dailyResetMode: 'rolling' | 'fixed' | null
|
||||
dailyResetMode: QuotaResetMode | null
|
||||
dailyResetHour: number | null
|
||||
weeklyResetMode: 'rolling' | 'fixed' | null
|
||||
weeklyResetMode: QuotaResetMode | null
|
||||
weeklyResetDay: number | null
|
||||
weeklyResetHour: number | null
|
||||
resetTimezone: string | null
|
||||
quotaNotifyGlobalEnabled?: boolean
|
||||
quotaNotifyDailyEnabled?: boolean | null
|
||||
quotaNotifyDailyThreshold?: number | null
|
||||
quotaNotifyDailyThresholdType?: string | null
|
||||
quotaNotifyDailyThresholdType?: QuotaThresholdType | null
|
||||
quotaNotifyWeeklyEnabled?: boolean | null
|
||||
quotaNotifyWeeklyThreshold?: number | null
|
||||
quotaNotifyWeeklyThresholdType?: string | null
|
||||
quotaNotifyWeeklyThresholdType?: QuotaThresholdType | null
|
||||
quotaNotifyTotalEnabled?: boolean | null
|
||||
quotaNotifyTotalThreshold?: number | null
|
||||
quotaNotifyTotalThresholdType?: string | null
|
||||
quotaNotifyTotalThresholdType?: QuotaThresholdType | null
|
||||
}>(), {
|
||||
quotaNotifyGlobalEnabled: false,
|
||||
quotaNotifyDailyEnabled: null,
|
||||
@@ -42,21 +43,21 @@ const emit = defineEmits<{
|
||||
'update:totalLimit': [value: number | null]
|
||||
'update:dailyLimit': [value: number | null]
|
||||
'update:weeklyLimit': [value: number | null]
|
||||
'update:dailyResetMode': [value: 'rolling' | 'fixed' | null]
|
||||
'update:dailyResetMode': [value: QuotaResetMode | null]
|
||||
'update:dailyResetHour': [value: number | null]
|
||||
'update:weeklyResetMode': [value: 'rolling' | 'fixed' | null]
|
||||
'update:weeklyResetMode': [value: QuotaResetMode | null]
|
||||
'update:weeklyResetDay': [value: number | null]
|
||||
'update:weeklyResetHour': [value: number | null]
|
||||
'update:resetTimezone': [value: string | null]
|
||||
'update:quotaNotifyDailyEnabled': [value: boolean | null]
|
||||
'update:quotaNotifyDailyThreshold': [value: number | null]
|
||||
'update:quotaNotifyDailyThresholdType': [value: string | null]
|
||||
'update:quotaNotifyDailyThresholdType': [value: QuotaThresholdType | null]
|
||||
'update:quotaNotifyWeeklyEnabled': [value: boolean | null]
|
||||
'update:quotaNotifyWeeklyThreshold': [value: number | null]
|
||||
'update:quotaNotifyWeeklyThresholdType': [value: string | null]
|
||||
'update:quotaNotifyWeeklyThresholdType': [value: QuotaThresholdType | null]
|
||||
'update:quotaNotifyTotalEnabled': [value: boolean | null]
|
||||
'update:quotaNotifyTotalThreshold': [value: number | null]
|
||||
'update:quotaNotifyTotalThresholdType': [value: string | null]
|
||||
'update:quotaNotifyTotalThresholdType': [value: QuotaThresholdType | null]
|
||||
}>()
|
||||
|
||||
const enabled = computed(() =>
|
||||
@@ -89,11 +90,6 @@ watch(localEnabled, (val) => {
|
||||
}
|
||||
})
|
||||
|
||||
// Whether any fixed mode is active (to show timezone selector)
|
||||
const hasFixedMode = computed(() =>
|
||||
props.dailyResetMode === 'fixed' || props.weeklyResetMode === 'fixed'
|
||||
)
|
||||
|
||||
// Common timezone options
|
||||
const timezoneOptions = [
|
||||
'UTC', 'Asia/Shanghai', 'Asia/Tokyo', 'Asia/Seoul', 'Asia/Singapore', 'Asia/Kolkata',
|
||||
@@ -102,18 +98,6 @@ const timezoneOptions = [
|
||||
'America/Sao_Paulo', 'Australia/Sydney', 'Pacific/Auckland',
|
||||
]
|
||||
|
||||
// Compute GMT offset label (e.g. "GMT+8", "GMT-5") for a given IANA timezone.
|
||||
function getTimezoneOffsetLabel(tz: string): string {
|
||||
try {
|
||||
const dtf = new Intl.DateTimeFormat('en-US', { timeZone: tz, timeZoneName: 'shortOffset' })
|
||||
const parts = dtf.formatToParts(new Date())
|
||||
const tzPart = parts.find(p => p.type === 'timeZoneName')
|
||||
return tzPart ? (tzPart.value === 'GMT' ? 'GMT+0' : tzPart.value) : ''
|
||||
} catch {
|
||||
return ''
|
||||
}
|
||||
}
|
||||
|
||||
// Hours for dropdown (0-23)
|
||||
const hourOptions = Array.from({ length: 24 }, (_, i) => i)
|
||||
|
||||
@@ -197,6 +181,7 @@ const dailyFixedHint = computed(() =>
|
||||
:hint-fixed="dailyFixedHint"
|
||||
:hour-options="hourOptions"
|
||||
:day-options="dayOptions"
|
||||
:timezone-options="timezoneOptions"
|
||||
@update:limit="emit('update:dailyLimit', $event)"
|
||||
@update:notify-enabled="emit('update:quotaNotifyDailyEnabled', $event)"
|
||||
@update:notify-threshold="emit('update:quotaNotifyDailyThreshold', $event)"
|
||||
@@ -223,6 +208,7 @@ const dailyFixedHint = computed(() =>
|
||||
:hint-fixed="weeklyFixedHint"
|
||||
:hour-options="hourOptions"
|
||||
:day-options="dayOptions"
|
||||
:timezone-options="timezoneOptions"
|
||||
@update:limit="emit('update:weeklyLimit', $event)"
|
||||
@update:notify-enabled="emit('update:quotaNotifyWeeklyEnabled', $event)"
|
||||
@update:notify-threshold="emit('update:quotaNotifyWeeklyThreshold', $event)"
|
||||
@@ -233,14 +219,6 @@ const dailyFixedHint = computed(() =>
|
||||
@update:reset-timezone="emit('update:resetTimezone', $event)"
|
||||
/>
|
||||
|
||||
<!-- Timezone selector (shared by daily/weekly when fixed mode is active) -->
|
||||
<div v-if="hasFixedMode">
|
||||
<label class="input-label">{{ t('admin.accounts.quotaResetTimezone') }}</label>
|
||||
<select :value="resetTimezone || 'UTC'" @change="emit('update:resetTimezone', ($event.target as HTMLSelectElement).value)" class="input text-sm">
|
||||
<option v-for="tz in timezoneOptions" :key="tz" :value="tz">{{ tz }} ({{ getTimezoneOffsetLabel(tz) }})</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<!-- Total quota -->
|
||||
<QuotaDimensionRow
|
||||
dim="total"
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
<script setup lang="ts">
|
||||
import { QUOTA_THRESHOLD_TYPE_FIXED, QUOTA_THRESHOLD_TYPE_PERCENTAGE } from '@/constants/account'
|
||||
import { QUOTA_THRESHOLD_TYPE_FIXED, QUOTA_THRESHOLD_TYPE_PERCENTAGE, type QuotaThresholdType } from '@/constants/account'
|
||||
|
||||
defineProps<{
|
||||
enabled: boolean | null
|
||||
threshold: number | null
|
||||
thresholdType: string | null // "fixed" (default) or "percentage"
|
||||
thresholdType: QuotaThresholdType | null
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:enabled': [value: boolean | null]
|
||||
'update:threshold': [value: number | null]
|
||||
'update:thresholdType': [value: string | null]
|
||||
'update:thresholdType': [value: QuotaThresholdType | null]
|
||||
}>()
|
||||
</script>
|
||||
|
||||
@@ -43,7 +43,7 @@ const emit = defineEmits<{
|
||||
/>
|
||||
<select
|
||||
:value="thresholdType || QUOTA_THRESHOLD_TYPE_FIXED"
|
||||
@change="emit('update:thresholdType', ($event.target as HTMLSelectElement).value)"
|
||||
@change="emit('update:thresholdType', ($event.target as HTMLSelectElement).value as QuotaThresholdType)"
|
||||
class="input py-1 text-xs w-[4.5rem] flex-shrink-0 text-center"
|
||||
>
|
||||
<option :value="QUOTA_THRESHOLD_TYPE_FIXED">$</option>
|
||||
|
||||
@@ -313,10 +313,6 @@
|
||||
<span class="text-gray-400">{{ t('usage.rate') }}</span>
|
||||
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.rate_multiplier || 1) }}x</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between gap-6">
|
||||
<span class="text-gray-400">{{ t('usage.accountMultiplier') }}</span>
|
||||
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.account_rate_multiplier ?? 1) }}x</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between gap-6">
|
||||
<span class="text-gray-400">{{ t('usage.original') }}</span>
|
||||
<span class="font-medium text-white">${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}</span>
|
||||
@@ -325,7 +321,12 @@
|
||||
<span class="text-gray-400">{{ t('usage.userBilled') }}</span>
|
||||
<span class="font-semibold text-green-400">${{ tooltipData?.actual_cost?.toFixed(6) || '0.000000' }}</span>
|
||||
</div>
|
||||
<!-- Account billing (separated from user billing) -->
|
||||
<div class="flex items-center justify-between gap-6 border-t border-gray-700 pt-1.5">
|
||||
<span class="text-gray-400">{{ t('usage.accountMultiplier') }}</span>
|
||||
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.account_rate_multiplier ?? 1) }}x</span>
|
||||
</div>
|
||||
<div class="flex items-center justify-between gap-6">
|
||||
<span class="text-gray-400">{{ t('usage.accountBilled') }}</span>
|
||||
<span class="font-semibold text-green-400">
|
||||
${{ accountBilled({
|
||||
@@ -355,7 +356,8 @@ import { getBillingModeLabel, getBillingModeBadgeClass, BILLING_MODE_TOKEN, BILL
|
||||
/** Compute the account-billed cost for display: (account_stats_cost ?? total_cost) * rate_multiplier */
|
||||
function accountBilled(row: { total_cost?: number | null; account_stats_cost?: number | null; account_rate_multiplier?: number | null }): number {
|
||||
const base = row.account_stats_cost != null ? row.account_stats_cost : (row.total_cost ?? 0)
|
||||
return base * (row.account_rate_multiplier ?? 1)
|
||||
const result = base * (row.account_rate_multiplier ?? 1)
|
||||
return Number.isNaN(result) ? 0 : result
|
||||
}
|
||||
|
||||
import DataTable from '@/components/common/DataTable.vue'
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { reactive, ref } from 'vue'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import { QUOTA_THRESHOLD_TYPE_FIXED } from '@/constants/account'
|
||||
import { QUOTA_THRESHOLD_TYPE_FIXED, type QuotaThresholdType } from '@/constants/account'
|
||||
|
||||
export const QUOTA_NOTIFY_DIMS = ['daily', 'weekly', 'total'] as const
|
||||
export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number]
|
||||
@@ -8,7 +8,7 @@ export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number]
|
||||
interface DimState {
|
||||
enabled: boolean | null
|
||||
threshold: number | null
|
||||
thresholdType: string | null
|
||||
thresholdType: QuotaThresholdType | null
|
||||
}
|
||||
|
||||
export function useQuotaNotifyState() {
|
||||
@@ -34,7 +34,7 @@ export function useQuotaNotifyState() {
|
||||
for (const d of QUOTA_NOTIFY_DIMS) {
|
||||
state[d].enabled = (extra?.[`quota_notify_${d}_enabled`] as boolean) ?? null
|
||||
state[d].threshold = (extra?.[`quota_notify_${d}_threshold`] as number) ?? null
|
||||
state[d].thresholdType = (extra?.[`quota_notify_${d}_threshold_type`] as string) ?? null
|
||||
state[d].thresholdType = (extra?.[`quota_notify_${d}_threshold_type`] as QuotaThresholdType) ?? null
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,3 +8,8 @@ export type WebSearchMode = typeof WEB_SEARCH_MODE_DEFAULT | typeof WEB_SEARCH_M
|
||||
export const QUOTA_THRESHOLD_TYPE_FIXED = 'fixed' as const
|
||||
export const QUOTA_THRESHOLD_TYPE_PERCENTAGE = 'percentage' as const
|
||||
export type QuotaThresholdType = typeof QUOTA_THRESHOLD_TYPE_FIXED | typeof QUOTA_THRESHOLD_TYPE_PERCENTAGE
|
||||
|
||||
/** Quota reset mode values */
|
||||
export const QUOTA_RESET_MODE_ROLLING = 'rolling' as const
|
||||
export const QUOTA_RESET_MODE_FIXED = 'fixed' as const
|
||||
export type QuotaResetMode = typeof QUOTA_RESET_MODE_ROLLING | typeof QUOTA_RESET_MODE_FIXED
|
||||
|
||||
@@ -166,8 +166,8 @@
|
||||
class="channel-tab group"
|
||||
:class="activeTab === section.platform ? 'channel-tab-active' : 'channel-tab-inactive'"
|
||||
>
|
||||
<PlatformIcon :platform="section.platform" size="xs" :class="getPlatformTextColor(section.platform)" />
|
||||
<span :class="getPlatformTextColor(section.platform)">{{ t('admin.groups.platforms.' + section.platform, section.platform) }}</span>
|
||||
<PlatformIcon :platform="section.platform" size="xs" :class="platformTextClass(section.platform)" />
|
||||
<span :class="platformTextClass(section.platform)">{{ t('admin.groups.platforms.' + section.platform, section.platform) }}</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
@@ -246,8 +246,8 @@
|
||||
class="h-3.5 w-3.5 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
@change="togglePlatform(p)"
|
||||
/>
|
||||
<PlatformIcon :platform="p" size="xs" :class="getPlatformTextColor(p)" />
|
||||
<span :class="getPlatformTextColor(p)">{{ t('admin.groups.platforms.' + p, p) }}</span>
|
||||
<PlatformIcon :platform="p" size="xs" :class="platformTextClass(p)" />
|
||||
<span :class="platformTextClass(p)">{{ t('admin.groups.platforms.' + p, p) }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
@@ -310,9 +310,9 @@
|
||||
class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
@change="toggleGroupInSection(sIdx, group.id)"
|
||||
/>
|
||||
<span :class="['font-medium', getPlatformTextColor(group.platform)]">{{ group.name }}</span>
|
||||
<span :class="['font-medium', platformTextClass(group.platform)]">{{ group.name }}</span>
|
||||
<span
|
||||
:class="['rounded-full px-1 py-0 text-[10px]', getRateBadgeClass(group.platform)]"
|
||||
:class="['rounded-full px-1 py-0 text-[10px]', platformBadgeLightClass(group.platform)]"
|
||||
>{{ group.rate_multiplier }}x</span>
|
||||
<span class="text-[10px] text-gray-400">{{ group.account_count || 0 }}</span>
|
||||
<span
|
||||
@@ -363,7 +363,7 @@
|
||||
:value="srcModel"
|
||||
type="text"
|
||||
class="input flex-1 text-xs"
|
||||
:class="getPlatformTextColor(section.platform)"
|
||||
:class="platformTextClass(section.platform)"
|
||||
:placeholder="t('admin.channels.form.mappingSource', 'Source model')"
|
||||
@change="renameMappingKey(sIdx, srcModel, ($event.target as HTMLInputElement).value)"
|
||||
/>
|
||||
@@ -372,7 +372,7 @@
|
||||
:value="section.model_mapping[srcModel]"
|
||||
type="text"
|
||||
class="input flex-1 text-xs"
|
||||
:class="getPlatformTextColor(section.platform)"
|
||||
:class="platformTextClass(section.platform)"
|
||||
:placeholder="t('admin.channels.form.mappingTarget', 'Target model')"
|
||||
@input="section.model_mapping[srcModel] = ($event.target as HTMLInputElement).value"
|
||||
/>
|
||||
@@ -464,7 +464,7 @@
|
||||
: 'border-gray-200 hover:bg-gray-50 dark:border-dark-600 dark:hover:bg-dark-700'"
|
||||
>
|
||||
<input type="checkbox" :checked="rule.group_ids.includes(gid)" class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500" @change="rule.group_ids.includes(gid) ? rule.group_ids.splice(rule.group_ids.indexOf(gid), 1) : rule.group_ids.push(gid)" />
|
||||
<span>{{ getGroupNameById(gid) }}</span>
|
||||
<span :class="['font-medium', platformTextClass(section.platform)]">{{ getGroupNameById(gid) }}</span>
|
||||
</label>
|
||||
</div>
|
||||
<p v-if="section.group_ids.length === 0" class="mt-1 text-xs text-gray-400">
|
||||
@@ -481,7 +481,7 @@
|
||||
:key="accountId"
|
||||
class="inline-flex items-center gap-1 rounded-md border border-primary-300 bg-primary-50 px-2 py-0.5 text-xs dark:border-primary-700 dark:bg-primary-900/20"
|
||||
>
|
||||
<span>{{ getRuleAccountLabel(accountId) }}</span>
|
||||
<span :class="['font-medium', platformTextClass(section.platform)]">{{ getRuleAccountLabel(accountId) }}</span>
|
||||
<button type="button" @click="removeRuleAccount(rule, accountId)" class="text-gray-400 hover:text-red-500">
|
||||
<Icon name="x" size="xs" />
|
||||
</button>
|
||||
@@ -595,7 +595,7 @@ import type { PricingFormEntry } from '@/components/admin/channel/types'
|
||||
import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types'
|
||||
import type { AdminGroup, GroupPlatform } from '@/types'
|
||||
import type { Column } from '@/components/common/types'
|
||||
import { platformTextClass } from '@/utils/platformColors'
|
||||
import { platformTextClass, platformBadgeLightClass } from '@/utils/platformColors'
|
||||
import AppLayout from '@/components/layout/AppLayout.vue'
|
||||
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
|
||||
import DataTable from '@/components/common/DataTable.vue'
|
||||
@@ -720,26 +720,6 @@ let abortController: AbortController | null = null
|
||||
// ── Platform config ──
|
||||
const platformOrder: GroupPlatform[] = ['anthropic', 'openai', 'gemini', 'antigravity']
|
||||
|
||||
function getPlatformTextColor(platform: string): string {
|
||||
switch (platform) {
|
||||
case 'anthropic': return 'text-orange-600 dark:text-orange-400'
|
||||
case 'openai': return 'text-emerald-600 dark:text-emerald-400'
|
||||
case 'gemini': return 'text-blue-600 dark:text-blue-400'
|
||||
case 'antigravity': return 'text-purple-600 dark:text-purple-400'
|
||||
default: return 'text-gray-600 dark:text-gray-400'
|
||||
}
|
||||
}
|
||||
|
||||
function getRateBadgeClass(platform: string): string {
|
||||
switch (platform) {
|
||||
case 'anthropic': return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
|
||||
case 'openai': return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
|
||||
case 'gemini': return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400'
|
||||
case 'antigravity': return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
default: return 'bg-gray-100 text-gray-700 dark:bg-gray-900/30 dark:text-gray-400'
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helpers ──
|
||||
function formatDate(value: string): string {
|
||||
if (!value) return '-'
|
||||
|
||||
Reference in New Issue
Block a user