Compare commits

...

113 Commits

Author SHA1 Message Date
Wesley Liddick
70d0569f08 Merge pull request #1668 from tyqy12/main
修复 OpenAI 账号限流回流误判:7d 窗口可用时不因 5h 窗口为 0 回写 429
2026-04-15 16:48:48 +08:00
Wesley Liddick
1db32d692b Merge pull request #1666 from touwaeriol/feat/account-cost-display
feat(usage): add account cost display to admin dashboard and usage pages
2026-04-15 16:43:07 +08:00
Wesley Liddick
8fd29082c0 Merge pull request #1663 from touwaeriol/fix/test-dialog-close-during-stream
fix(ui): allow closing test dialog during active SSE stream
2026-04-15 16:40:40 +08:00
Wesley Liddick
9bf079b725 Merge pull request #1655 from touwaeriol/feat/payment-fee-multiplier
feat(payment): balance recharge multiplier and fee rate
2026-04-15 16:40:14 +08:00
erio
e180dd0710 fix(usage): remove label text from inline account cost, keep orange color 2026-04-15 16:09:58 +08:00
erio
a7dd535d47 fix(usage): show account cost inline under cost column, remove separate column
- Cost cell: change gray "A $xxx" to orange "成本 $xxx" with i18n
- Remove standalone account_cost column from column settings (redundant)
2026-04-15 15:59:51 +08:00
erio
db27e8f000 feat(usage): add account cost to breakdown sub-table and admin usage log
- UserBreakdownItem: add AccountCost field + SQL aggregation
- UserBreakdownSubTable: add orange account cost column
- Admin usage table: add account_cost column (after cost, default visible)
- Column settings: add account_cost toggle option
2026-04-15 15:40:40 +08:00
Wesley Liddick
7451b6f9ae 修复 OpenAI 账号限流回流误判:7d 窗口可用时不因 5h 窗口为 0 回写 429 2026-04-15 15:29:52 +08:00
erio
e0b12b7512 fix(usage): put cost label before value in usage stats card 2026-04-15 15:02:21 +08:00
erio
22680dc602 test(usage): add unit tests for account_cost and fix gofmt
- Fix mock for GetModelStatsWithFilters: add account_cost column
- Add assertion: GetStatsWithFilters always returns TotalAccountCost
- New test: GetModelStatsAccountCostColumn verifies scan of AccountCost
- New test: GetGroupStatsAccountCostColumn verifies scan of AccountCost
- New test: GetStatsWithFiltersAlwaysReturnsAccountCost (no AccountID filter)
- Integration test: add TotalAccountCost/TodayAccountCost assertions
- Fix gofmt alignment in usage_log_types.go
2026-04-15 15:02:21 +08:00
erio
6ade6d30a8 feat(usage): add account cost display to admin dashboard and usage pages
- Add account_cost column to dashboard aggregation tables (migration 107)
- DashboardStats: add TotalAccountCost/TodayAccountCost fields
- ModelStat/GroupStat: add AccountCost field with SQL aggregation
- GetStatsWithFilters: always return TotalAccountCost (remove accountID filter)
- Dashboard Token cards: show user(green)/cost(orange)/standard(gray)
- Usage stats card: show account cost and standard below main value
- Model/Group distribution tables: add orange cost column
2026-04-15 15:02:21 +08:00
erio
38c00872e1 fix(ui): allow closing test dialog during active SSE stream
Replace dead EventSource variable with AbortController to enable
cancelling fetch streams. Remove close-button disable during connecting
status so users can dismiss the dialog at any time.
2026-04-15 11:34:31 +08:00
erio
c2108421c2 fix: gofmt payment_service.go and payment_order.go 2026-04-15 01:50:19 +08:00
erio
342dbd2e19 fix(payment): use original recharge amount in product name, not pay_amount
Product name (e.g. "快代码科技工作室 100 元") should show the user's
original recharge amount (limitAmount), not the fee-inclusive pay amount.
The gateway receives payAmount separately for actual charging.
2026-04-15 01:43:56 +08:00
erio
21f22b5099 fix: remove accidentally staged Antigravity-Manager submodule 2026-04-15 01:39:27 +08:00
erio
60614e6f74 fix: gofmt formatting and update API contract test for new fields
- Fix gofmt alignment in setting_handler.go, settings.go, payment_config_service.go
- Add payment_balance_recharge_multiplier and payment_recharge_fee_rate
  to API contract test expected JSON
2026-04-15 01:39:00 +08:00
erio
3053c56cac fix(payment): show full amount breakdown on payment result page
- Show base amount (充值金额) as first line
- Show fee amount with percentage when fee_rate > 0
- Show pay_amount (实付金额) in bold primary color
- Show credited amount (到账金额) when different from pay_amount
- Compute baseAmount and feeAmount from backend order data
2026-04-15 01:27:25 +08:00
erio
d149dbc91f fix(payment): enhance fee rate input validation and UI
Backend:
- Validate recharge_fee_rate: 0 ≤ rate ≤ 100, max 2 decimal places

Frontend settings:
- Add % suffix icon to fee rate input
- Enforce max=100, min=0, step=0.01 with 2 decimal precision
2026-04-15 01:27:24 +08:00
erio
e761d38fd1 fix(payment): integrate recharge fee rate in order flow and fix UI display
Backend:
- Use cfg.RechargeFeeRate in order creation instead of hardcoded 0
- Remove dead getFeeRate stub method
- All amounts computed server-side: order_amount, pay_amount, fee_rate

Frontend - PaymentView:
- Read recharge_fee_rate from checkout-info API (not per-method)
- Show fee breakdown only when fee_rate > 0
- Show credited amount only when multiplier ≠ 1

Frontend - Order display (user + admin):
- Fix fee_rate * 100 bug (fee_rate is already a percentage)
- OrderTable: show pay_amount as primary, fee/credited as sub-lines
- AdminOrderDetail: full breakdown (base/fee/paid/credited)
- AdminRefundDialog: label "到账金额" for clarity
- PaymentResultView: show pay_amount with fee info

Types + i18n:
- Add recharge_fee_rate to CheckoutInfoResponse
- Add fee_rate to CreateOrderResult
- Add translations: creditedAmount, fee, baseAmount, includedInPayAmount
2026-04-15 01:27:24 +08:00
erio
98140f6cac feat(payment): add recharge fee rate setting and fix provider card UI
- Add recharge_fee_rate system setting (percentage fee on top of recharge amount)
- Full backend chain: config constant, PaymentConfig struct, update validation,
  read/write persistence, DTO, handler GET/PUT responses
- Frontend: settings input with preview, i18n (zh/en), API types
- Fix provider card toggle layout: labels above switches to save width
- Fix Chinese translation: "EasyPay" → "易支付" in provider description
2026-04-15 01:27:24 +08:00
erio
60a4b9316b feat(payment): balance recharge multiplier and refund amount separation
- Add balance_recharge_multiplier system setting (e.g. 1.2 = charge 100 get 120)
- Separate order_amount (credited balance) from pay_amount (actual payment)
- Refund calculates gateway amount proportionally from pay_amount
- Frontend shows both amounts in order details, payment status, refund dialog
- Admin settings UI for configuring recharge multiplier
2026-04-15 01:27:24 +08:00
Wesley Liddick
7c671b5373 Merge pull request #1635 from KnowSky404/fix-issue-1613-version-dropdown
fix(sidebar): prevent version dropdown clipping in expanded brand
2026-04-14 20:41:53 +08:00
Wesley Liddick
d402e722cf Merge pull request #1637 from touwaeriol/feat/websearch-notify-pricing
feat: web search emulation, balance/quota notify, account stats pricing, per-provider refund control, Stripe fix / Web 搜索模拟、余额配额通知、渠道统计计费、按服务商退款控制、Stripe 修复
2026-04-14 20:41:09 +08:00
erio
8548a130c7 fix: Messages() routing refactor and subscription group test coverage
- Refactor OpenAI Messages() routing: pre-compute dispatch model using
  resolveOpenAIMessagesDispatchMappedModel + NormalizeOpenAICompatRequestedModel
  instead of try-fail-retry pattern with gin context passing
- Remove openai_messages_fallback_model context anti-pattern
- Use effectiveMappedModel directly for forward default mapped model
- Add 3 subscription group tests covering all branch paths:
  _Blocked (no active subscription → SUBSCRIPTION_REQUIRED),
  _RequiresRepo (nil repo → SUBSCRIPTION_REPOSITORY_UNAVAILABLE),
  _AllowsActiveSubscription (valid subscription → success)
2026-04-14 20:34:53 +08:00
erio
3d2027227b fix: update wire_gen.go to use ProvideSchedulerCache with config injection
wire_gen.go was calling NewSchedulerCache(redisClient) but wire.go had
been updated to register ProvideSchedulerCache(redisClient, config),
which reads SnapshotMGetChunkSize and SnapshotWriteChunkSize from config.
Without this fix, those config values were silently ignored.
2026-04-14 20:22:45 +08:00
erio
3fa5b8bca5 fix: flaky WebSocket test, usage request queue, and test improvements
- Fix flaky WebSocket passthrough test: allow StatusNormalClosure after
  client close instead of requiring NoError (race condition fix)
- Fix ratelimit 401 test: use PlatformOpenAI instead of PlatformGemini
  for OAuth token cache invalidation scenario (more accurate)
- Add usageLoadQueue: Anthropic OAuth/setup-token accounts sharing the
  same proxy exit are serialized with 1-2s jitter to prevent upstream 429
- AccountUsageCell: add module-level usage cache (5min TTL), unmounted
  safety guard, and integrate enqueueUsageRequest for throttled fetching
2026-04-14 20:13:59 +08:00
erio
5240b44452 refactor(payment): inline payment flow, mobile support, renewal modal
Replace dialog-based payment with inline state flow (select → paying/stripe).
- PaymentStatusPanel replaces QR dialog for scan-to-pay
- StripePaymentInline replaces Stripe popup
- Subscription confirm as inline card instead of modal
- Payment button color follows payment method
- Renewal modal with URL parameter navigation (?tab=subscription&group=123)
- Mobile auto-redirect for H5 payment
- AmountInput uses global min/max instead of per-method
- Tab auto-hides during payment
- Restore CNY (¥) currency for upstream compatibility
2026-04-14 19:45:53 +08:00
erio
a56151fec9 refactor: extract CapacityBadge component from AccountCapacityCell
Extract repeated badge template (SVG icon + current/max display) into
a reusable CapacityBadge component. Reduces AccountCapacityCell from
~300 lines to ~180 lines with identical behavior.
2026-04-14 19:39:22 +08:00
erio
63f539b382 fix: merge general improvements from release branch
Backend:
- gateway_handler: pass subject.UserID instead of int64(0) for user-level routing
- setting_handler: add missing BalanceLowNotifyRechargeURL to UpdateSettings response
- openai_gateway_service: use applyAccountStatsCost for account stats pricing integration
- embed_on: add local file override (data/public/) for embedded frontend assets

Frontend:
- useTableSelection: add batchUpdate method for batch operations
- AccountsView: virtual scrolling params, Set-based isSelected, swipe virtualization
- ProxiesView: add batchUpdate to selection and swipe-select
- BulkEditAccountModal: fix submit handler to prevent event object as argument
- SettingsView: move payload construction outside try block
- i18n: add general translation keys (saved, deleted, view, validation, allowUserRefund)
- api/client: reorder error fields for consistency
- stores/payment: clarify pollOrderStatus JSDoc
2026-04-14 19:29:37 +08:00
erio
c14d739360 fix: resolve 3 code review issues in allow_user_refund
1. PrepareRefund: block refund on provider instance lookup failure
   instead of silently skipping permission check (medium severity)

2. UpdateProviderInstance: allow enabling refund_enabled and
   allow_user_refund in the same request by checking req.RefundEnabled
   value before falling back to DB read

3. ExecuteRefund: only revoke subscription on ErrAdjustWouldExpire,
   abort on other errors (DB failure, not found) instead of
   unconditionally revoking
2026-04-14 18:41:09 +08:00
erio
58677dd53f fix: merge 5 PR-related improvements
- gateway_handler: pass ParsedRequest to RecordUsage + set in gin.Context
- channel_handler: add FeaturesConfig to CRUD (WebSearch channel toggle)
- channel_repo: features_config JSONB persistence (Create/Get/Update/List)
- security_headers: add Stripe CSP domains (script-src + frame-src)
2026-04-14 18:34:57 +08:00
erio
6ac8ccde46 fix: merge 30 general improvements from release branch
Bug fixes:
- Detached context for GetAccountConcurrencyBatch (prevent all-zero on request cancel)
- Filter soft-deleted users in GetByGroupID
- Stripe CSP policy (allow Stripe.js in script-src and frame-src)
- WebSearch API key validation on save
- RECHARGING status in payment result success check
- Windows test fixes (logger Sync deadlock, config path escaping)

Feature enhancements:
- Webhook multi-instance dispatch (extractOutTradeNo + GetWebhookProvider)
- EasyPay mobile H5 payment (device param + PayURL2)
- SSE error propagation in WebSearch emulation
- AccountStatsCost DTO field for admin usage logs
- Plans sort by sort_order instead of created_at
- UsageMapHook for streaming response usage data
- apicompat Instructions field passthrough
- EffectiveLoadFactor for ops concurrency/metrics
- Usage billing RETURNING balance for notify system
- BulkUpdate mixed channel warning with details
- println to slog migration in auth cache
- Wire ProviderSet cleanup
- CI cache-dependency-path optimization

Frontend:
- Refund eligibility check per provider (canRequestRefund)
- Plan sort_order editing
- Dead code cleanup (simulate_claude_max, client_affinity)
- GroupsView platform switch guard
- channels features_config API type
- UsageView account_stats_cost export
2026-04-14 17:35:27 +08:00
erio
f1297a3694 feat: add per-provider allow_user_refund control and align wildcard matching
allow_user_refund:
- Add allow_user_refund field to PaymentProviderInstance ent schema
- Migration 103: ALTER TABLE payment_provider_instances ADD COLUMN
- Cascade logic: disabling refund_enabled auto-disables allow_user_refund
- User refund validation: check provider instance allows user refund
- Admin refund validation: check provider instance allows admin refund
- Subscription refund: deduct days on refund, rollback on failure
- New endpoint: GET /payment/orders/refund-eligible-providers
- Frontend: ToggleSwitch in ProviderCard/Dialog, cascade in SettingsView

Wildcard matching:
- Change findPricingForModel from "longest prefix wins" to "config order
  priority (first match wins)", aligning with channel service behavior
2026-04-14 16:26:46 +08:00
erio
e8ee400a3f fix: resolve remaining lint errors for upstream CI
- Fix errcheck: brave.go resp.Body.Close, manager_test.go Encode
- Fix gofmt: payment_config_service.go
- Fix unused: use shouldFallbackGeminiModel (with modelName param) in handler
2026-04-14 12:19:44 +08:00
erio
6a08efeef9 fix: resolve upstream CI failures (lint, test, gofmt)
- Fix errcheck: handle Write/Encode return values in brave_test.go
- Fix errcheck: defer resp.Body.Close() with _ assignment in tavily.go
- Fix gofmt: payment.go, channel.go, payment_config_providers.go
- Fix unused: remove dead decodeURLValue in easypay.go
- Restore shouldFallbackGeminiModel function (deleted during cherry-pick)
- Add missing balanceNotifyService param to NewGatewayService in test
- Fix platform default test expectation (empty stays empty)
- Fix wildcard pricing test (longest prefix wins, not config order)
- Fix subscription group test (SUBSCRIPTION_REPOSITORY_UNAVAILABLE)
2026-04-14 12:11:08 +08:00
erio
4aa0070e3d fix: Stripe payment type matching in load balancer
Checkout page aggregates Stripe sub-types (card,link,alipay,wxpay) under
"stripe", but SelectInstance matched against supported_types literally,
which doesn't contain "stripe". Now matches by provider_key for Stripe.
2026-04-14 11:31:44 +08:00
erio
b42f34c359 fix: resolve test compilation errors and restore upstream VERSION
- Add missing interface methods to test stubs (RemoveGroupFromUserAllowedGroups,
  GetNotifyCodeUserRate, IncrNotifyCodeUserRate, UpdateGroupIDByUserAndGroup)
- Fix NewUserService call signatures (add 4th param)
- Fix GetAccountCount return signature (3 values)
- Update api_contract_test.go snapshots for balance_notify fields
- Restore resolveOpenAIMessagesDispatchMappedModel function
- Reset VERSION to upstream 0.1.112
2026-04-14 11:27:32 +08:00
erio
24e16b7f59 fix: restore resolveOpenAIMessagesDispatchMappedModel and reset VERSION
- Restore function deleted during cherry-pick conflict resolution
- Reset VERSION to upstream 0.1.112
2026-04-14 10:58:51 +08:00
erio
d6965b0676 fix: resolve cherry-pick conflicts and restore compilation
- Restore gateway_cache.go to upstream (no lua embeds)
- Restore payment_order.go to upstream (use out_trade_no lookup)
- Restore payment_fulfillment.go to upstream (same reason)
- Add FeaturesConfig field and IsWebSearchEmulationEnabled to Channel
- Add applyAccountStatsCost wrapper function
- Add SettingKeyWebSearchEmulationConfig constant
- Add WebSearchEmulationEnabled to SystemSettings
- Add notify code rate limiting methods to EmailCache interface
- Remove AllowUserRefund references (ent schema not present)
- Fix duplicate import in payment_handler.go
- Fix wire_gen.go argument mismatches
2026-04-14 10:18:39 +08:00
erio
9028d2085f test: add unit tests for billing, websearch, and notify systems
Billing (25 tests):
- CalculateCostUnified: nil resolver fallback, token/per_request/image modes
- GetModelPricingWithChannel: nil/partial/full channel overrides
- resolveAccountStatsCost: four-level priority chain integration tests

WebSearch (18 tests):
- PopulateWebSearchUsage: nil input, manager states, QuotaLimit nil/*int64
- ResetWebSearchUsage: nil manager error
- Manager.ResetUsage: nil Redis
- shouldEmulateWebSearch: full decision chain (8 scenarios)

Notify (36 tests):
- ParseNotifyEmails/MarshalNotifyEmails: old/new format, roundtrip
- crossedDownward: boundary values, threshold semantics
- checkQuotaDimCrossings: mixed dimensions, disabled/zero skip
2026-04-14 09:36:40 +08:00
erio
7c7292935e feat: websearch quota enhancements and balance notify hint
- QuotaLimit changed to *int64 (null=unlimited, >0=limited)
- Add reset-usage endpoint (POST /admin/settings/web-search-emulation/reset-usage)
- Show quota usage in header always (collapsed and expanded)
- Add reset quota button in expanded provider view
- Quota input: empty=unlimited with ∞ placeholder, must be >0 if set
- Add email verification hint on balance notify card
2026-04-14 09:36:40 +08:00
erio
1e6912ea2e fix: gofmt formatting across all Go source files 2026-04-14 09:36:26 +08:00
erio
9e0d12d3b0 fix: show websearch API key visibility/copy buttons for saved providers
The buttons were hidden because v-if only checked provider.api_key,
which is always empty for saved providers (backend sanitizes it).
Now also checks api_key_configured. Copy button is disabled when
no actual key is available (only configured placeholder shown).
2026-04-14 09:35:46 +08:00
erio
b402c367d3 fix: add opportunistic STARTTLS to sendMailPlain for 587 port compatibility
smtp.SendMail automatically upgrades to STARTTLS when the server
supports it. Our replacement sendMailPlain skipped this, causing
credentials to be sent in plaintext on port 587. Add STARTTLS
negotiation before Auth to restore the original security behavior.
2026-04-14 09:35:21 +08:00
erio
0a4ece5f5b fix: audit round-3 — proxy safety, intervals persistence, SMTP timeout, sort fix
- Skip websearch provider when ProxyID is set but proxy not found (prevent
  silent direct connection bypass)
- Fix sortByStableRandomWeight: pair factors with items so sort.Slice swap
  keeps weights aligned
- Allow empty platform in account_stats_pricing_rules (wildcard matching),
  only force anthropic default for main model_pricing
- Add channel_account_stats_pricing_intervals table and repo layer support
  for interval-based pricing in account stats rules
- calculateTokenStatsCost now uses interval pricing when available
- Replace smtp.SendMail/tls.Dial with net.Dialer timeout (10s dial, 20s IO)
  to prevent goroutine leak on SMTP hang
- Fix gofmt formatting issues
- Web Search label: black text with red warning hint
2026-04-14 09:35:20 +08:00
erio
9c09bd19b4 fix: websearch features_config cleanup and pricing rules validation
- Fix web_search_emulation toggle: explicitly write false for disabled
  platforms instead of leaving stale true from cloned features_config
- Extract validatePricingEntries from validateChannelConfig for reuse
- Validate account_stats_pricing_rules[].pricing in both Create and
  Update paths (negative prices, bad intervals, missing per_request price)
2026-04-14 09:35:20 +08:00
erio
a9880ee7b9 fix: round-2 audit fixes — security, code quality, and UI improvements
Security (HIGH):
- Normalize all Redis cache keys to lowercase (verifyCode, passwordReset)
- Fix verify code TTL renewal on failed attempts: use remaining TTL via
  ExpiresAt field instead of resetting to full 15-minute window
- Add 3 missing fields to diffSettings audit log (promo_code, invitation_code,
  custom_endpoints)

Code quality (MEDIUM):
- Extract filterVerifiedEmails shared helper (balance_notify_service.go)
- Add Pricing array non-empty validation for channel pricing rules
- Add platform token semantics comment in gateway_service.go
- Complete validatePlanPatch test coverage (+10 test cases)
- Replace string types with QuotaThresholdType/QuotaResetMode across frontend
- Remove duplicate getPlatformTextColor/getRateBadgeClass in ChannelsView
- Return EMAIL_NOT_FOUND error on RemoveNotifyEmail miss

UI improvements:
- Reorder cost tooltip: user billing above separator, account billing below
- Add NaN guard to accountBilled function
- Move timezone selector inline into reset-mode row (no longer standalone)
2026-04-14 09:35:05 +08:00
erio
74f8a30f86 fix: address audit findings for websearch, email verification, and pricing
- Fix websearch provider failover: proxy error from provider-specific proxy
  now continues to next provider instead of aborting the entire loop
- Fix SMTP failure locking users out: send email first, then write cache
  and increment rate counter
- Fix notify email cache key case sensitivity: normalize to lowercase
- Add OriginalPrice validation to validatePlanPatch and validatePlanRequired
- Add empty scope validation for channel pricing rules (group_ids/account_ids)
- Add platform color to account search dropdown in channel pricing rules
2026-04-14 09:33:53 +08:00
erio
1b7c295199 refactor: M5 useQuotaNotifyState composable + H14 Vue file splits
M5: New composable frontend/src/composables/useQuotaNotifyState.ts
  - Replaces 9 individual refs in both Create/Edit modals with reactive state
  - Provides loadFromExtra/writeToExtra/reset helpers
  - Eliminates ~120 lines of duplicated code across the two modals

H14: Vue file length violations fixed
  - AdminPaymentPlansView.vue: 325 → 183 lines (extracted PlanEditDialog.vue)
  - QuotaLimitCard.vue: 327 → 268 lines (extracted QuotaDimensionRow.vue)
  - PlanEditDialog.vue: 181 lines (new, plan create/edit form)
  - QuotaDimensionRow.vue: 108 lines (new, single quota dimension row)
2026-04-14 09:33:39 +08:00
erio
594f0d17d1 refactor: batch 3 — decompose CheckBalanceAfterDeduction, merge crossing checks, add QuotaNotifyConfig
M1: CheckBalanceAfterDeduction (63→18 lines) decomposed into:
    canNotifyBalance, resolveUserEffectiveThreshold, crossedDownward, dispatchBalanceLowEmail
M3: New Account.QuotaNotifyConfig(dim) method replaces 9 hardcoded getters
    (getters kept as thin wrappers for backward compatibility)
M4: checkQuotaDimCrossings + checkQuotaDimCrossingsFromState merged into one
    function taking pre-built []quotaDim; caller builds dims conditionally
2026-04-14 09:33:00 +08:00
erio
9d319cfa2d fix: batch 2 audit fixes — diffSettings notify fields, slog migration, frontend constants
H5: diffSettings now tracks 5 balance/quota notify fields in audit log
M15: log.Printf audit log migrated to slog.Info, removed "log" import
M14: New frontend/src/constants/account.ts with shared constants
     QuotaNotifyToggle.vue uses QUOTA_THRESHOLD_TYPE_FIXED/PERCENTAGE
L2: UsageTable.vue uses BILLING_MODE_TOKEN/IMAGE from billingMode.ts
2026-04-14 09:32:24 +08:00
erio
ed8a9d975b fix: batch 1 audit fixes — quota SQL fixed mode, public recharge URL, WebSearch bool fallback, UpdatePlan validation
H1: incrementUsageBillingAccountQuota now uses shared dailyExpiredExpr/weeklyExpiredExpr
    constants (supporting fixed reset mode) instead of hardcoded '24 hours'/'168 hours'
H4: public settings endpoint now maps balance_low_notify_recharge_url
H6: GetWebSearchEmulationMode tolerates legacy bool values (true→enabled)
H7: UpdatePlan validates non-nil patch fields (rejects negative price, empty name, etc.)
H8: UsageTable accountBilled() helper with total_cost ?? 0 null guard
H9: AdminUsageLog TS type adds channel_id + billing_tier
M2: account.go "fixed" literals replaced with thresholdTypeFixed constant
M13: SystemSettings TS type adds web_search_emulation_enabled
UI: QuotaLimitCard title labels now use flex-1 to align with flex-1 input boxes
2026-04-14 09:32:11 +08:00
erio
ca673f9899 test: add 66 unit tests for balance/quota notify + plan validation
balance_notify_service_test.go (27 tests):
- resolveBalanceThreshold: fixed/percentage/zero recharged/empty type
- quotaDim.resolvedThreshold: fixed normal/exceed/equal limit, percentage 0/30/100/>100, zero/negative limit
- sanitizeEmailHeader: CRLF/CR/LF/clean/empty/multiple newlines
- buildQuotaDims / buildQuotaDimsFromState: all dimensions, empty extra, state-vs-account precedence
- collectBalanceNotifyRecipients: empty, filter disabled/unverified, case-insensitive dedup, skip empty, trim

balance_notify_check_test.go (16 tests):
- CheckBalanceAfterDeduction guard clauses: nil user/disabled/global-off/threshold=0/user-override/no-crossing
- CheckAccountQuotaAfterIncrement guards: nil account/zero cost/negative cost/global-disabled
- getBalanceNotifyConfig: all fields, disabled, invalid threshold
- isAccountQuotaNotifyEnabled: missing/false/true
- getSiteName: default fallback + configured

balance_notify_email_body_test.go (10 tests):
- Guards against fmt.Sprintf arg-count mismatches in email templates
- Verifies HTML escaping of recharge URL
- Verifies CSS %% escape produces literal % in output
- Verifies unlimited/percentage/over-quota display branches

payment_config_plans_validation_test.go (13 tests):
- validatePlanRequired: all 5 validation branches + whitespace handling
2026-04-14 09:31:45 +08:00
erio
a43da62254 fix(accounts): unify modal width, add notify props to create, fix quota layout
- EditAccountModal width changed from "normal" to "wide" (match CreateAccountModal)
- CreateAccountModal now passes all quota notify props to QuotaLimitCard
- QuotaLimitCard: when global notify disabled, hide title row, input takes full width
- Quota alert email: show remaining quota + threshold (fixed/$, percentage/%) instead of usage trigger point
2026-04-14 09:31:32 +08:00
erio
6e9146e746 fix(notify): add recharge URL to admin settings GET response 2026-04-14 09:31:08 +08:00
erio
f571d8ffad fix(notify): write back auto-filled recharge URL to form on save 2026-04-14 09:31:08 +08:00
erio
48b6c4811f fix(notify): auto-fill recharge URL with current origin when empty 2026-04-14 09:31:08 +08:00
erio
c1eb79e4ba feat(notify): add platform/ID to quota alert email, add recharge URL to balance alert
- Quota alert email now shows account ID and platform
- Balance low email includes a "Top Up Now" button when recharge URL is configured
- New setting: balance_low_notify_recharge_url in admin settings
2026-04-14 09:30:51 +08:00
erio
e27335acdd fix(ui): widen notify type dropdown to show % fully, align quota input widths 2026-04-14 09:30:02 +08:00
erio
216bda58da fix: change quota notify threshold semantics to "remaining quota"
Threshold now represents remaining quota instead of usage amount:
- Fixed ($): threshold=400, limit=1000 → alert when remaining drops to $400
  (i.e., usage reaches $600)
- Percentage (%): threshold=30%, limit=1000 → alert when remaining drops
  to 30% (i.e., usage reaches $700)

Also:
- Rename 告警阈值 → 提醒阈值 in i18n
- Widen type dropdown to w-16 for proper $ / % display
2026-04-14 09:29:25 +08:00
erio
7141dceee2 fix(frontend): place quota notify toggle inline with limit input
Move QuotaNotifyToggle to the same row as the limit $ input for all
three dimensions (daily/weekly/total), significantly reducing card height.
2026-04-14 09:29:01 +08:00
erio
ac55443278 fix(frontend): collapsible quota card and compact notify layout
- QuotaLimitCard: add collapse/expand toggle (chevron icon + click header)
- QuotaNotifyToggle: show $ or % suffix in threshold input
- Reduce vertical spacing between reset mode hint and notify toggle
2026-04-14 09:28:48 +08:00
erio
2066c478ab fix(frontend): quota notify UI improvements
- QuotaNotifyToggle: add $ or % suffix to threshold input based on type
- QuotaLimitCard: combine reset mode and notify toggle on same row
  to reduce vertical height for daily/weekly sections
- Remove redundant ml-4 indentation from QuotaNotifyToggle
2026-04-14 09:28:24 +08:00
erio
98c9d51791 fix: correct account stats pricing priority order
Priority was wrong:
- Before: custom rules → LiteLLM (when ApplyPricingToAccountStats) → nil
- After:  custom rules → totalCost (when ApplyPricingToAccountStats) → LiteLLM → nil

When ApplyPricingToAccountStats is enabled, use the request's actual
client billing cost (before multiplier) as account_stats_cost, instead
of recalculating from LiteLLM per-token prices which produced incorrect
values for per-request billing mode.

LiteLLM model pricing is now the final fallback (priority 3), used only
when neither custom rules nor ApplyPricingToAccountStats apply.
2026-04-14 09:28:11 +08:00
erio
42f8ef3315 fix: add missing AccountQuotaNotifyEnabled to admin settings API
The field was present in SystemSettings response DTO and service layer
but missing from:
- UpdateSettingsRequest (admin handler) - saves were silently ignored
- GET/PUT response mapping in admin handler
- UpdateSettingsRequest (non-admin dto)

This caused the toggle to always revert to off after saving.
2026-04-14 09:27:47 +08:00
erio
245f47cebb fix(frontend): simplify websearch select labels and reduce width
- "默认(跟随渠道)" → "默认", "Default (follow channel)" → "Default"
- Move "follows channel config" info to description text
- Reduce select width from w-32 to w-24 in both Edit and Create modals
2026-04-14 09:27:46 +08:00
erio
48e8efe3e8 fix(frontend): hide quota notify toggle when global setting is disabled
QuotaLimitCard now requires quotaNotifyGlobalEnabled prop to control
visibility of QuotaNotifyToggle components. When the global account
quota notification is disabled in admin settings, per-account threshold
toggles are hidden in both Edit and Create account modals.
2026-04-14 09:27:33 +08:00
knowsky404
58c0f57647 fix(sidebar): prevent version dropdown clipping in expanded brand 2026-04-14 09:27:02 +08:00
erio
b1875f0b82 fix: round 3 audit fixes - SMTP header sanitization and goroutine safety
- Move sanitizeEmailHeader to SendEmailWithConfig entry point, covering all
  email senders (verify code, password reset, ops alerts, notifications)
- Add panic recovery to UpdateBalance goroutine
- Fix stale comment in getAccountQuotaNotifyEmails (email="" no longer used)
- Log error instead of silently discarding verifyNotifyCode cache update failure
2026-04-14 09:26:46 +08:00
erio
b7fb2e4387 fix: audit fixes for websearch, notifications, and channel pricing
P0: fix wildcard matching test assertion (config order, not longest prefix)
P0: add TotalRecharged to auth cache snapshot (v5) for percentage threshold
P1: move pricing rules into per-platform sections in ChannelsView
P1: populate account name cache when editing existing channel rules
P1: sanitize email subject headers to prevent SMTP injection
P1: make Redis INCR+EXPIRE idempotent for rate limiting
P1: deep copy FeaturesConfig in Channel.Clone()
P2: clean up stale email="" placeholder comments
P2: replace log.Printf with slog in email_service.go
2026-04-14 09:26:32 +08:00
erio
a68df457d8 fix: address audit findings across websearch, notify, and channel pricing
Backend fixes:
- Fix balance notify ignoring percentage threshold type (was treating
  percentage value as fixed USD amount)
- Remove dead code parseJSONStringArray
- Add ImageOutputTokens to tryModelFilePricing calculation
- Unify zero-value check: cost == 0 → cost <= 0 in calculateTokenStatsCost
- Use MarshalNotifyEmails instead of json.Marshal for consistency
- Rename quotaDim.oldUsed → currentUsed for clarity
- Extract HTML email templates to const variables (function ≤30 lines)

Test fixes:
- Rewrite account_websearch_test.go for GetWebSearchEmulationMode tri-state
- Add 6 tryModelFilePricing test cases

Frontend fixes:
- Replace hardcoded '未命名' with i18n key
- Extract getBillingModeLabel/getBillingModeBadgeClass to shared utils
- Replace inline type with imported NotifyEmailEntry
- Pass platform to AccountStats pricing rules via inferRulePlatform()
- Add billing mode constants (BILLING_MODE_TOKEN/PER_REQUEST/IMAGE)
2026-04-14 09:26:08 +08:00
erio
1262654d97 feat: WebSearch tri-state, account stats pricing fix, quota cache fix, usage tooltip
WebSearch tri-state switch:
- Account-level web_search_emulation changed from bool to tri-state
  string: "default" (follow channel) / "enabled" / "disabled"
- shouldEmulateWebSearch checks channel config when account is "default"
- SQL migration converts old bool values
- Frontend select replaces toggle in Edit/CreateAccountModal

Account stats pricing:
- resolveAccountStatsCost uses upstream model (post-mapping) for matching
- Priority: custom rules → model pricing file (when toggle on) → default
- Custom rules always configurable, independent of toggle
- Account ID field changed to searchable selector filtered by platform
- Description updated to reflect new behavior

Quota notification cache fix:
- CheckAccountQuotaAfterIncrement fetches real-time account from DB
- Reconstructs pre-increment usage for accurate threshold crossing detection
- New AccountQuotaReader interface (minimal: GetByID only)

Usage tooltip:
- Per-request/image billing shows per-request price instead of $0 token price
- Token billing continues to show input/output price per million tokens
2026-04-14 09:26:08 +08:00
erio
11c4606874 fix(channel): use upstream model for account stats pricing and remove channel pricing fallback
- resolveAccountStatsCost now uses the final upstream model (after
  account-level mapping) to match custom pricing rules, fixing the
  issue where requested model (e.g. claude-sonnet-4-5) didn't match
  rules configured for upstream model (e.g. claude-opus-4-6)
- Remove tryChannelPricing fallback — only custom rules are applied,
  unmatched requests use default formula (total_cost × rate)
- Remove unused billingService and serviceTier parameters
- Update description: "启用后将支持自定义账号统计的模型价格"
2026-04-14 09:26:08 +08:00
erio
95f9b27e70 fix(notify): add verification flow for saved unverified emails
- Add "verify" button next to saved unverified emails in
  ProfileBalanceNotifyCard (send code → enter code → verify)
- Backend: VerifyAndAddNotifyEmail now marks existing unverified
  emails as verified instead of returning "already exists"
- Inline verification UI with countdown timer and resend button
2026-04-14 09:26:08 +08:00
erio
31550a2c6a fix(notify): use real-time balance for crossing detection and simplify email logic
- Fix cached balance causing threshold crossing to never trigger:
  read real-time balance from billingCacheService instead of stale
  API key auth snapshot
- Remove email="" placeholder concept; all emails are user-managed
- Only send notifications to verified && non-disabled emails
- Frontend: pre-fill user's email in add input when list is empty
- Remove FilterEnabledEmails/IsPrimaryDisabled helpers (no longer needed)
2026-04-14 09:26:07 +08:00
erio
915b7a4a56 feat(notify): convert email lists to NotifyEmailEntry struct with toggle support
- Change balance_notify_extra_emails and account_quota_notify_emails
  from []string to []NotifyEmailEntry{email, disabled, verified}
- Add per-email enable/disable toggle for both user and admin notifications
- Add PUT /user/notify-email/toggle API endpoint
- Fix critical bug: API key auth cache snapshot missing balance notify
  fields (Email, Username, BalanceNotifyEnabled, etc.), causing
  notifications to never fire on cached request paths
- Bump cache snapshot version 3→4 to invalidate stale entries
- Add SQL migration 104 to convert old format data
- Backward compatible: parseNotifyEmails auto-detects old/new format
- User balance notify: max 3 emails (primary + 2 extra)
- Admin quota notify: unlimited emails, each with toggle
2026-04-14 09:26:07 +08:00
erio
61aa197b0b fix(notify): add explicit save button for balance threshold
Replace blur-based auto-save with an explicit Save button so users
know when their threshold is persisted. Shows success toast on save.
2026-04-14 09:26:07 +08:00
erio
422807514c fix(notify): add duplicate email check message and improve extra email UX 2026-04-14 09:26:07 +08:00
erio
81287e960c feat(notify): improve balance notify card UX
- Show system default threshold as placeholder in custom threshold input
- Display user's primary email with "Primary" badge
- Support adding multiple pending emails before verification
- Each pending email has independent send/verify/resend flow
- Expose balance_low_notify_threshold in PublicSettings API
- Clean up timers on unmount to prevent leaks
2026-04-14 09:25:50 +08:00
erio
79d154ed73 fix(notify): add balance/quota notify flags to PublicSettings DTO and handler
The service layer correctly populated BalanceLowNotifyEnabled and
AccountQuotaNotifyEnabled in PublicSettings, but the handler-to-DTO
mapping was missing. Users could not see the balance notify card because
the public settings API never returned these flags.
2026-04-14 09:25:50 +08:00
erio
49281bbe45 fix(websearch): hide show/copy buttons when API key is empty
Only show the inline eye/copy buttons when provider.api_key has a value.
When only api_key_configured is true (saved key, not loaded), buttons are
hidden since there's nothing to show/copy.
2026-04-14 09:25:50 +08:00
erio
5df7309979 fix(websearch): add 15s timeout for admin test search 2026-04-14 09:25:49 +08:00
erio
80fa484467 refactor(channels): move account stats pricing rules from basic to platform tabs
- Basic settings now only shows the global toggle
- Custom pricing rules appear inside each platform tab when toggle is on
- Group selector in rules scoped to the current platform's groups
- Remove unused allFormGroupIds computed
2026-04-14 09:25:49 +08:00
erio
4e96a6faec fix: address audit findings for notify, websearch and security
- Fix GetByKeyForAuth missing user.FieldEmail and user.FieldUsername (notifications sent to empty address)
- Guard against empty email in collectBalanceNotifyRecipients
- Remove non-atomic TotalRecharged read-modify-write in admin balance adjustment
- HTML-escape userName/siteName/accountName in notification email templates
- Fix timer leak in ProfileBalanceNotifyCard (add onUnmounted cleanup)
- Add warning log on websearch proxy URL resolution failure
2026-04-14 09:25:49 +08:00
erio
eba289a7ff feat(notify): add global toggles, percentage threshold, and visibility control
- Add global toggle for account quota notification in admin settings
- Add percentage-based threshold type for per-account quota alerts
- Hide balance notify card on user profile when global toggle is off
- Expose balance_low_notify_enabled and account_quota_notify_enabled in PublicSettings
- Add threshold type (fixed/percentage) to QuotaNotifyToggle with $ / % switcher
2026-04-14 09:25:49 +08:00
erio
889b5b4f3b fix(websearch): improve settings UI and hide config when globally disabled
- API Key show/copy buttons moved inside input field (inline icons)
- Proxy selector and test button on same row to save vertical space
- Test opens a dialog modal instead of inline display
- Hide all websearch config in channels/accounts when global toggle is off
2026-04-14 09:25:36 +08:00
erio
cef22c70ab fix(notify): remove percentage threshold from balance notification
Balance low notification only supports fixed USD amount threshold.
Percentage threshold is a quota concept, not applicable to balance.
Reverted threshold_type from admin settings, user profile, and all
backend/frontend layers. DB fields (balance_notify_threshold_type,
total_recharged) retained for potential future quota use.
2026-04-14 09:25:12 +08:00
erio
9e33d0c4c0 fix: address audit findings for websearch and balance notification
- Fix GetByKeyForAuth not selecting balance notify fields (notifications
  never triggered in gateway path)
- Fix provider-level ProxyURL never resolved: inject ProxyRepository into
  SettingService, resolve proxy URLs when building Manager
- Fix admin manual balance adjustment not updating total_recharged
- Add threshold_type input validation (reject invalid values)
- Fix user threshold_type inheritance: custom threshold defaults to "fixed"
  instead of inheriting global type (prevents $5 being treated as 5%)
- Add try-catch for clipboard.writeText (fails on non-HTTPS)
- Add SetTotalRecharged to user Update for admin balance operations
2026-04-14 09:24:58 +08:00
erio
f694afbbf4 feat(notify): add percentage threshold type for balance low notification
- Add threshold_type field (fixed/percentage) to system and user settings
- Add total_recharged field to users table, auto-incremented on balance credit
- Percentage mode: effective threshold = total_recharged × percentage / 100
- User-level threshold_type inherits from system default when not set
- Update admin settings UI with radio selector (fixed amount / percentage)
- Migration: 102_add_balance_notify_threshold_type.sql
2026-04-14 09:24:17 +08:00
erio
d0674e0ff9 feat(websearch): settings UI overhaul and quota improvements
- Remove Priority field, auto load-balance by quota remaining
- Replace QuotaRefreshInterval (daily/weekly/monthly) with SubscribedAt
  (subscription date, monthly lazy refresh via Redis TTL)
- Add collapsible provider cards, API key show/copy, usage progress bar
- Add test endpoint (POST /web-search-emulation/test) bypassing quota
- Wire WebSearchManagerBuilder on startup (was never called before)
- Fix nextMonthlyReset day-of-month overflow (Jan 31 → Feb 28)
- Fix non-deterministic sort in selectByQuotaWeight
- Map ProxyID in builder for provider-level proxy tracking
- Fix frontend timezone drift in subscribed_at date picker
- Fix provider deletion index shift for expandedProviders state
2026-04-14 09:23:40 +08:00
erio
30b926add4 fix(notify): per-recipient timeout and return user on email removal
- Use per-recipient context timeout in sendEmails to prevent later
  recipients from failing due to shared timeout exhaustion
- Return updated user object from RemoveNotifyEmail handler for
  frontend state consistency (matching VerifyNotifyEmail pattern)
2026-04-14 09:23:16 +08:00
erio
c3812ce1e3 fix(notify): address review findings - accountCost formula, dedup, refactor
- Fix accountCost calculation in finalizePostUsageBilling to match
  postUsageBilling (always multiply by AccountRateMultiplier)
- Use strings.EqualFold for email dedup in collectBalanceNotifyRecipients
- Extract CheckAccountQuotaAfterIncrement into smaller functions:
  buildQuotaDims + asyncSendQuotaAlert (< 30 lines each)
- Add "not splittable" comments for HTML template functions
- Extract QuotaNotifyToggle.vue sub-component to reduce
  QuotaLimitCard.vue from 404 to 339 lines
2026-04-14 09:23:16 +08:00
erio
b32d1a2c9f feat(notify): add balance low & account quota notification system
- User balance low notification: email alert when balance drops below
  configurable threshold (user email + verified extra emails)
- Account quota notification: broadcast email to admin-configured
  recipients when daily/weekly/total quota usage exceeds alert threshold
- Admin settings: global enable/disable, default threshold, quota
  notification email list (Email Settings tab)
- User profile: enable/disable, custom threshold, add/remove extra
  notification emails with verification code flow
- Account quota: per-dimension alert toggle and threshold in quota
  control card
- Trigger logic: first-crossing only (old >= threshold && new < threshold
  for balance; old < threshold && new >= threshold for quota), naturally
  prevents duplicate notifications without Redis dedup
2026-04-14 09:23:02 +08:00
erio
60b0fa81ec fix(websearch): improve isProxyError detection and add manager tests
- Add TLS error detection to isProxyError (RecordHeaderError, handshake)
- Case-insensitive error string matching
- Add 19 unit tests for: isProviderAvailable, resolveProxyID,
  isProxyError, isProxyAvailable, selectByQuotaWeight, newHTTPClient
2026-04-14 09:22:26 +08:00
erio
499159870c fix: gofmt websearch manager 2026-04-14 09:22:25 +08:00
erio
fda61b067c feat(websearch): proxy failover, timeout, quota-weighted load balancing
- Use proxyutil.ConfigureTransportProxy for unified proxy protocol support
  (HTTP/HTTPS/SOCKS5/SOCKS5H), replacing ad-hoc HTTP-only proxy code
- Proxy errors return ErrProxyUnavailable → gateway triggers account switch
  via UpstreamFailoverError instead of fallback to direct connection
- Timeout: proxy dial 3s, TLS handshake 3s, data transfer 60s
- Mark proxy unavailable for 5 minutes in Redis on connectivity failure
- Quota-weighted load balancing: providers with quota_limit>0 are selected
  by remaining quota (weighted random); quota_limit=0 providers treated as
  0% weight and placed last
2026-04-14 09:22:25 +08:00
erio
7535e312e0 feat(channels): add custom account stats pricing rules
Allow channels to configure independent model pricing for account
statistics cost calculation, decoupled from user billing.

Backend:
- Migration 101: channels.apply_pricing_to_account_stats toggle,
  channel_account_stats_pricing_rules/model_pricing tables,
  usage_logs.account_stats_cost column
- resolveAccountStatsCost: match rules by group/account, then channel
  pricing, fallback to original formula when unconfigured
- Integrate into both GatewayService.recordUsageCore and
  OpenAIGatewayService.RecordUsage
- Update 8 account stats SQL queries to use
  COALESCE(account_stats_cost, total_cost) * account_rate_multiplier
- 23 unit tests for matching, pricing lookup, and cost calculation

Frontend:
- Channel edit dialog: toggle + custom rules UI with group/account
  multi-select and pricing entry cards
- API types and i18n (zh/en)
2026-04-14 09:22:12 +08:00
erio
7fad9f604f fix(test): add web_search_emulation_enabled to API contract test
The settings API response now includes the new field; update the
expected snapshot in TestAPIContracts to match.
2026-04-14 09:21:28 +08:00
erio
1b53ffcac7 feat(gateway): add web search emulation for Anthropic API Key accounts
Inject web search capability for Claude Console (API Key) accounts that
don't natively support Anthropic's web_search tool. When a pure
web_search request is detected, the gateway calls Brave Search or Tavily
API directly and constructs an Anthropic-protocol-compliant SSE/JSON
response without forwarding to upstream.

Backend:
- New `pkg/websearch/` SDK: Brave and Tavily provider implementations
  with io.LimitReader, proxy support, and Redis-based quota tracking
  (Lua atomic INCR + TTL, DECR rollback on failure)
- Global config via `settings.web_search_emulation_config` (JSON) with
  in-process cache + singleflight, input validation, API key merge on
  save, and sanitized API responses
- Channel-level toggle via `channels.features_config` JSONB column
  (DB migration 101)
- Account-level toggle via `accounts.extra.web_search_emulation`
- Request interception in `Forward()` with SSE streaming response
  construction using json.Marshal (no manual string concatenation)
- Manager hot-reload: `RebuildWebSearchManager()` called on config save
  and startup via `SetWebSearchRedisClient()`
- 70 unit tests covering providers, manager, config validation,
  sanitization, tool detection, query extraction, and response building

Frontend:
- Settings → Gateway tab: Web Search Emulation config card with global
  toggle, provider list (add/remove, API key, priority, quota, proxy)
- Channels → Anthropic tab: web search emulation toggle with global
  state linkage (disabled when global off)
- Account Create/Edit modals: web search emulation toggle for API Key
  type with Toggle component
- Full i18n coverage (zh + en)
2026-04-14 09:20:39 +08:00
erio
c738cfec93 fix(payment): critical audit fixes for security, idempotency and correctness
Backend fixes:
- #1: doSub subscription idempotency via audit log check
- #2: markFailed only when status=RECHARGING (prevents overwriting COMPLETED)
- #3: ExpireTimedOutOrders checks upstream payment before expiring
- #4: Public verify endpoint for payment result page (no auth required)
- #5: EasyPay QueryOrder returns amount, confirmPayment handles zero amount
- #6: WxPay notifyUrl priority: request-first, config-fallback
- #7: EasyPay remove double URL decode in VerifyNotification
- #8: checkPaid/cancelUpstreamPayment use order's provider instance
- #9: Amount NaN/Inf/negative validation in order creation and refund
- #10: Refund amount comparison uses tolerance instead of float64 ==
- #11: Skip balance deduction on retry when previous rollback failed
- #12: checkPaid logs fulfillment errors instead of silently ignoring
- #13: WxPay certSerial added to required config fields

Frontend fixes:
- Payment result page no longer requires authentication
- Public verify API fallback for expired sessions
2026-04-14 09:19:33 +08:00
erio
56e4a9a914 fix: audit fixes - magic strings to constants, frontend any/catch, LB tests
Backend:
- Define OrderTypeBalance/Subscription, EntityStatusActive, DeductionType*,
  NotificationStatus* constants in payment/types.go
- Replace all magic strings in payment_order, payment_fulfillment, payment_refund
- Add local constants in easypay.go (tradeStatusSuccess, signTypeMD5)
- Add 27 unit tests for load balancer (filterByLimits, pickLeastAmount,
  getInstanceChannelLimits, startOfDay)

Frontend:
- Remove all `any` types in SettingsView.vue (18 catch blocks + 1 payload)
- Fix bare catch blocks in PaymentResultView, PaymentView
- Add `unknown` type annotation to all catch blocks

chore: bump version to 0.1.108.140
2026-04-14 09:18:58 +08:00
erio
3c884f8e30 test(payment): add unit tests for payment audit fixes + allow empty supported_types
Tests (1033 new lines, 100% coverage on modified functions):
- amount.go: YuanToFen/FenToYuan with precision edge cases
- wxpay: mapWxState, wxSV, formatPEM, NewWxpay validation
- alipay: isTradeNotExist, NewAlipay validation
- webhook: writeSuccessResponse (wxpay JSON, stripe empty, others text)
- config: validateProviderRequest, isSensitiveConfigField, joinTypes
- fulfillment: resolveRedeemAction idempotency logic

Business logic changes:
- Allow empty supported_types on provider instances
- Block removing payment types when instance has pending orders
- Extract resolveRedeemAction as testable pure function
2026-04-14 09:18:22 +08:00
erio
5bae3b0577 fix(payment): audit fixes for alipay/wxpay/stripe payment providers
Backend:
- Extract YuanToFen/FenToYuan to payment/amount.go using shopspring/decimal
- Require alipay publicKey in config validation
- Fix wxpay webhook response to return JSON per V3 spec
- Remove wxpay certSerial fallback to publicKeyId
- Define magic strings as named constants in wxpay/alipay providers
- Add slog warning for wxpay H5→Native payment downgrade
- Make EncryptionKey validation return error on invalid (non-empty) key
- Make decryptConfig propagate errors instead of returning nil
- Add idempotency check in doBalance to prevent stuck FAILED retries

Frontend:
- Fix dashboard currency symbol from $ to ¥
- Fix AdminPaymentPlansView any type to proper SubscriptionPlan type
- Make quick amount buttons follow selected payment method limits
- Center help image with larger height and text below
2026-04-14 09:17:06 +08:00
erio
1c63ea1448 fix(channel): add missing features column to List query
The paginated List query was selecting 9 columns but scanning 10 fields,
missing c.features. GetByID and ListAll already included it correctly.
2026-04-14 09:16:13 +08:00
erio
3d4d960d60 fix: gofmt formatting after merge 2026-04-14 09:15:49 +08:00
erio
794e817208 refactor: remove PaymentChannel, reuse upstream Channel with features field
- Delete payment_channels table and PaymentChannel Ent schema
- Add `features` column to upstream channels table (migration 095)
- Add Features field to Channel struct, input types, handler request/response
- Payment user/admin handlers now use ChannelService directly
- Remove Channel CRUD from PaymentConfigService and admin payment routes
- Remove "渠道管理" tab from admin orders page (use /admin/channels)
2026-04-14 09:15:29 +08:00
erio
37c23eccfe fix: gofmt formatting 2026-04-14 09:14:29 +08:00
erio
e374874125 feat(channel): improve cache strategy and add restriction logging
- Change channel cache TTL from 60s to 10min (reduce unnecessary DB queries)
- Actively rebuild cache after CRUD instead of lazy invalidation
- Add slog.Warn logging for channel pricing restriction blocks (4 places)
2026-04-14 09:13:53 +08:00
erio
160903fce7 fix: address review findings for channel restriction refactoring
- Fix 7 stale comments still mentioning "限制检查" in handlers/services
- Make billingModelForRestriction explicitly list channel_mapped case
- Add slog.Warn for error swallowing in ResolveChannelMapping and
  needsUpstreamChannelRestrictionCheck
- Document sticky session upstream check exemption
2026-04-14 09:12:42 +08:00
erio
2dce4306b4 refactor: move channel model restriction from handler to scheduling phase
Move the model pricing restriction check from 8 handler entry points
to the account scheduling phase (SelectAccountForModelWithExclusions /
SelectAccountWithLoadAwareness), aligning restriction with billing:

- requested: check original request model against pricing list
- channel_mapped: check channel-mapped model against pricing list
- upstream: per-account check using account-mapped model

Handler layer now only resolves channel mapping (no restriction).
Scheduling layer performs pre-check for requested/channel_mapped,
and per-account filtering for upstream billing source.
2026-04-14 09:12:29 +08:00
erio
3de7713017 fix(channel): splice替换model_pricing条目 + 增强调试日志 2026-04-14 09:08:58 +08:00
erio
1cd033e521 style: apply gofmt formatting
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-14 09:08:00 +08:00
github-actions[bot]
e534e9bae8 chore: sync VERSION to 0.1.112 [skip ci] 2026-04-13 15:24:14 +00:00
214 changed files with 14387 additions and 1644 deletions

View File

@@ -17,6 +17,7 @@ jobs:
go-version-file: backend/go.mod
check-latest: false
cache: true
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.26.2'
@@ -36,6 +37,7 @@ jobs:
go-version-file: backend/go.mod
check-latest: false
cache: true
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.26.2'

1
Antigravity-Manager Submodule

Submodule Antigravity-Manager added at a9d96bd549

View File

@@ -1 +1 @@
0.1.111
0.1.112

View File

@@ -36,19 +36,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
// Business layer ProviderSets
repository.ProviderSet,
service.ProviderSet,
payment.ProviderSet,
middleware.ProviderSet,
handler.ProviderSet,
// Server layer ProviderSet
server.ProviderSet,
// Payment providers
payment.ProvideRegistry,
payment.ProvideEncryptionKey,
payment.ProvideDefaultLoadBalancer,
service.ProvidePaymentConfigService,
service.ProvidePaymentOrderExpiryService,
// Privacy client factory for OpenAI training opt-out
providePrivacyClientFactory,

View File

@@ -50,8 +50,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
settingRepository := repository.NewSettingRepository(client)
groupRepository := repository.NewGroupRepository(client, db)
channelRepository := repository.NewChannelRepository(db)
settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
proxyRepository := repository.NewProxyRepository(client, db)
settingService := service.ProvideSettingService(settingRepository, groupRepository, proxyRepository, configConfig)
emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier()
@@ -65,23 +65,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
apiKeyService.SetRateLimitCacheInvalidator(billingCache)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
registry := payment.ProvideRegistry()
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
if err != nil {
return nil, err
}
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
if err != nil {
return nil, err
@@ -89,10 +79,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
userHandler := handler.NewUserHandler(userService)
userHandler := handler.NewUserHandler(userService, emailService, emailCache)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemHandler := handler.NewRedeemHandler(redeemService)
@@ -112,7 +101,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
schedulerCache := repository.ProvideSchedulerCache(redisClient, configConfig)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
privacyClientFactory := providePrivacyClientFactory()
@@ -120,11 +108,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient()
@@ -134,7 +125,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig)
@@ -142,23 +132,20 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
identityCache := repository.NewIdentityCache(redisClient)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
oAuthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService()
@@ -175,6 +162,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
@@ -183,17 +171,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
@@ -221,8 +210,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
channelHandler := admin.NewChannelHandler(channelService, billingService)
adminPaymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, adminPaymentHandler)
registry := payment.ProvideRegistry()
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
if err != nil {
return nil, err
}
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -245,7 +244,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)

View File

@@ -616,6 +616,7 @@ var (
{Name: "sort_order", Type: field.TypeInt, Default: 0},
{Name: "limits", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "refund_enabled", Type: field.TypeBool, Default: false},
{Name: "allow_user_refund", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
}
@@ -1078,6 +1079,11 @@ var (
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
{Name: "balance_notify_enabled", Type: field.TypeBool, Default: true},
{Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"},
{Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}},
{Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{

View File

@@ -15642,25 +15642,26 @@ func (m *PaymentOrderMutation) ResetEdge(name string) error {
// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph.
type PaymentProviderInstanceMutation struct {
config
op Op
typ string
id *int64
provider_key *string
name *string
_config *string
supported_types *string
enabled *bool
payment_mode *string
sort_order *int
addsort_order *int
limits *string
refund_enabled *bool
created_at *time.Time
updated_at *time.Time
clearedFields map[string]struct{}
done bool
oldValue func(context.Context) (*PaymentProviderInstance, error)
predicates []predicate.PaymentProviderInstance
op Op
typ string
id *int64
provider_key *string
name *string
_config *string
supported_types *string
enabled *bool
payment_mode *string
sort_order *int
addsort_order *int
limits *string
refund_enabled *bool
allow_user_refund *bool
created_at *time.Time
updated_at *time.Time
clearedFields map[string]struct{}
done bool
oldValue func(context.Context) (*PaymentProviderInstance, error)
predicates []predicate.PaymentProviderInstance
}
var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil)
@@ -16105,6 +16106,42 @@ func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() {
m.refund_enabled = nil
}
// SetAllowUserRefund sets the "allow_user_refund" field.
func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) {
m.allow_user_refund = &b
}
// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation.
func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) {
v := m.allow_user_refund
if v == nil {
return
}
return *v, true
}
// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity.
// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldAllowUserRefund requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err)
}
return oldValue.AllowUserRefund, nil
}
// ResetAllowUserRefund resets all changes to the "allow_user_refund" field.
func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() {
m.allow_user_refund = nil
}
// SetCreatedAt sets the "created_at" field.
func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) {
m.created_at = &t
@@ -16211,7 +16248,7 @@ func (m *PaymentProviderInstanceMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *PaymentProviderInstanceMutation) Fields() []string {
fields := make([]string, 0, 11)
fields := make([]string, 0, 12)
if m.provider_key != nil {
fields = append(fields, paymentproviderinstance.FieldProviderKey)
}
@@ -16239,6 +16276,9 @@ func (m *PaymentProviderInstanceMutation) Fields() []string {
if m.refund_enabled != nil {
fields = append(fields, paymentproviderinstance.FieldRefundEnabled)
}
if m.allow_user_refund != nil {
fields = append(fields, paymentproviderinstance.FieldAllowUserRefund)
}
if m.created_at != nil {
fields = append(fields, paymentproviderinstance.FieldCreatedAt)
}
@@ -16271,6 +16311,8 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) {
return m.Limits()
case paymentproviderinstance.FieldRefundEnabled:
return m.RefundEnabled()
case paymentproviderinstance.FieldAllowUserRefund:
return m.AllowUserRefund()
case paymentproviderinstance.FieldCreatedAt:
return m.CreatedAt()
case paymentproviderinstance.FieldUpdatedAt:
@@ -16302,6 +16344,8 @@ func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name str
return m.OldLimits(ctx)
case paymentproviderinstance.FieldRefundEnabled:
return m.OldRefundEnabled(ctx)
case paymentproviderinstance.FieldAllowUserRefund:
return m.OldAllowUserRefund(ctx)
case paymentproviderinstance.FieldCreatedAt:
return m.OldCreatedAt(ctx)
case paymentproviderinstance.FieldUpdatedAt:
@@ -16378,6 +16422,13 @@ func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value)
}
m.SetRefundEnabled(v)
return nil
case paymentproviderinstance.FieldAllowUserRefund:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetAllowUserRefund(v)
return nil
case paymentproviderinstance.FieldCreatedAt:
v, ok := value.(time.Time)
if !ok {
@@ -16483,6 +16534,9 @@ func (m *PaymentProviderInstanceMutation) ResetField(name string) error {
case paymentproviderinstance.FieldRefundEnabled:
m.ResetRefundEnabled()
return nil
case paymentproviderinstance.FieldAllowUserRefund:
m.ResetAllowUserRefund()
return nil
case paymentproviderinstance.FieldCreatedAt:
m.ResetCreatedAt()
return nil
@@ -28210,6 +28264,13 @@ type UserMutation struct {
totp_secret_encrypted *string
totp_enabled *bool
totp_enabled_at *time.Time
balance_notify_enabled *bool
balance_notify_threshold_type *string
balance_notify_threshold *float64
addbalance_notify_threshold *float64
balance_notify_extra_emails *string
total_recharged *float64
addtotal_recharged *float64
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -28927,6 +28988,240 @@ func (m *UserMutation) ResetTotpEnabledAt() {
delete(m.clearedFields, user.FieldTotpEnabledAt)
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (m *UserMutation) SetBalanceNotifyEnabled(b bool) {
m.balance_notify_enabled = &b
}
// BalanceNotifyEnabled returns the value of the "balance_notify_enabled" field in the mutation.
func (m *UserMutation) BalanceNotifyEnabled() (r bool, exists bool) {
v := m.balance_notify_enabled
if v == nil {
return
}
return *v, true
}
// OldBalanceNotifyEnabled returns the old "balance_notify_enabled" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldBalanceNotifyEnabled(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldBalanceNotifyEnabled is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldBalanceNotifyEnabled requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldBalanceNotifyEnabled: %w", err)
}
return oldValue.BalanceNotifyEnabled, nil
}
// ResetBalanceNotifyEnabled resets all changes to the "balance_notify_enabled" field.
func (m *UserMutation) ResetBalanceNotifyEnabled() {
m.balance_notify_enabled = nil
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (m *UserMutation) SetBalanceNotifyThresholdType(s string) {
m.balance_notify_threshold_type = &s
}
// BalanceNotifyThresholdType returns the value of the "balance_notify_threshold_type" field in the mutation.
func (m *UserMutation) BalanceNotifyThresholdType() (r string, exists bool) {
v := m.balance_notify_threshold_type
if v == nil {
return
}
return *v, true
}
// OldBalanceNotifyThresholdType returns the old "balance_notify_threshold_type" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldBalanceNotifyThresholdType(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldBalanceNotifyThresholdType is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldBalanceNotifyThresholdType requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldBalanceNotifyThresholdType: %w", err)
}
return oldValue.BalanceNotifyThresholdType, nil
}
// ResetBalanceNotifyThresholdType resets all changes to the "balance_notify_threshold_type" field.
func (m *UserMutation) ResetBalanceNotifyThresholdType() {
m.balance_notify_threshold_type = nil
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (m *UserMutation) SetBalanceNotifyThreshold(f float64) {
m.balance_notify_threshold = &f
m.addbalance_notify_threshold = nil
}
// BalanceNotifyThreshold returns the value of the "balance_notify_threshold" field in the mutation.
func (m *UserMutation) BalanceNotifyThreshold() (r float64, exists bool) {
v := m.balance_notify_threshold
if v == nil {
return
}
return *v, true
}
// OldBalanceNotifyThreshold returns the old "balance_notify_threshold" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldBalanceNotifyThreshold(ctx context.Context) (v *float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldBalanceNotifyThreshold is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldBalanceNotifyThreshold requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldBalanceNotifyThreshold: %w", err)
}
return oldValue.BalanceNotifyThreshold, nil
}
// AddBalanceNotifyThreshold adds f to the "balance_notify_threshold" field.
func (m *UserMutation) AddBalanceNotifyThreshold(f float64) {
if m.addbalance_notify_threshold != nil {
*m.addbalance_notify_threshold += f
} else {
m.addbalance_notify_threshold = &f
}
}
// AddedBalanceNotifyThreshold returns the value that was added to the "balance_notify_threshold" field in this mutation.
func (m *UserMutation) AddedBalanceNotifyThreshold() (r float64, exists bool) {
v := m.addbalance_notify_threshold
if v == nil {
return
}
return *v, true
}
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
func (m *UserMutation) ClearBalanceNotifyThreshold() {
m.balance_notify_threshold = nil
m.addbalance_notify_threshold = nil
m.clearedFields[user.FieldBalanceNotifyThreshold] = struct{}{}
}
// BalanceNotifyThresholdCleared returns if the "balance_notify_threshold" field was cleared in this mutation.
func (m *UserMutation) BalanceNotifyThresholdCleared() bool {
_, ok := m.clearedFields[user.FieldBalanceNotifyThreshold]
return ok
}
// ResetBalanceNotifyThreshold resets all changes to the "balance_notify_threshold" field.
func (m *UserMutation) ResetBalanceNotifyThreshold() {
m.balance_notify_threshold = nil
m.addbalance_notify_threshold = nil
delete(m.clearedFields, user.FieldBalanceNotifyThreshold)
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (m *UserMutation) SetBalanceNotifyExtraEmails(s string) {
m.balance_notify_extra_emails = &s
}
// BalanceNotifyExtraEmails returns the value of the "balance_notify_extra_emails" field in the mutation.
func (m *UserMutation) BalanceNotifyExtraEmails() (r string, exists bool) {
v := m.balance_notify_extra_emails
if v == nil {
return
}
return *v, true
}
// OldBalanceNotifyExtraEmails returns the old "balance_notify_extra_emails" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldBalanceNotifyExtraEmails(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldBalanceNotifyExtraEmails is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldBalanceNotifyExtraEmails requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldBalanceNotifyExtraEmails: %w", err)
}
return oldValue.BalanceNotifyExtraEmails, nil
}
// ResetBalanceNotifyExtraEmails resets all changes to the "balance_notify_extra_emails" field.
func (m *UserMutation) ResetBalanceNotifyExtraEmails() {
m.balance_notify_extra_emails = nil
}
// SetTotalRecharged sets the "total_recharged" field.
func (m *UserMutation) SetTotalRecharged(f float64) {
m.total_recharged = &f
m.addtotal_recharged = nil
}
// TotalRecharged returns the value of the "total_recharged" field in the mutation.
func (m *UserMutation) TotalRecharged() (r float64, exists bool) {
v := m.total_recharged
if v == nil {
return
}
return *v, true
}
// OldTotalRecharged returns the old "total_recharged" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldTotalRecharged(ctx context.Context) (v float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldTotalRecharged is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldTotalRecharged requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldTotalRecharged: %w", err)
}
return oldValue.TotalRecharged, nil
}
// AddTotalRecharged adds f to the "total_recharged" field.
func (m *UserMutation) AddTotalRecharged(f float64) {
if m.addtotal_recharged != nil {
*m.addtotal_recharged += f
} else {
m.addtotal_recharged = &f
}
}
// AddedTotalRecharged returns the value that was added to the "total_recharged" field in this mutation.
func (m *UserMutation) AddedTotalRecharged() (r float64, exists bool) {
v := m.addtotal_recharged
if v == nil {
return
}
return *v, true
}
// ResetTotalRecharged resets all changes to the "total_recharged" field.
func (m *UserMutation) ResetTotalRecharged() {
m.total_recharged = nil
m.addtotal_recharged = nil
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -29501,7 +29796,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
fields := make([]string, 0, 14)
fields := make([]string, 0, 19)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -29544,6 +29839,21 @@ func (m *UserMutation) Fields() []string {
if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt)
}
if m.balance_notify_enabled != nil {
fields = append(fields, user.FieldBalanceNotifyEnabled)
}
if m.balance_notify_threshold_type != nil {
fields = append(fields, user.FieldBalanceNotifyThresholdType)
}
if m.balance_notify_threshold != nil {
fields = append(fields, user.FieldBalanceNotifyThreshold)
}
if m.balance_notify_extra_emails != nil {
fields = append(fields, user.FieldBalanceNotifyExtraEmails)
}
if m.total_recharged != nil {
fields = append(fields, user.FieldTotalRecharged)
}
return fields
}
@@ -29580,6 +29890,16 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.TotpEnabled()
case user.FieldTotpEnabledAt:
return m.TotpEnabledAt()
case user.FieldBalanceNotifyEnabled:
return m.BalanceNotifyEnabled()
case user.FieldBalanceNotifyThresholdType:
return m.BalanceNotifyThresholdType()
case user.FieldBalanceNotifyThreshold:
return m.BalanceNotifyThreshold()
case user.FieldBalanceNotifyExtraEmails:
return m.BalanceNotifyExtraEmails()
case user.FieldTotalRecharged:
return m.TotalRecharged()
}
return nil, false
}
@@ -29617,6 +29937,16 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx)
case user.FieldBalanceNotifyEnabled:
return m.OldBalanceNotifyEnabled(ctx)
case user.FieldBalanceNotifyThresholdType:
return m.OldBalanceNotifyThresholdType(ctx)
case user.FieldBalanceNotifyThreshold:
return m.OldBalanceNotifyThreshold(ctx)
case user.FieldBalanceNotifyExtraEmails:
return m.OldBalanceNotifyExtraEmails(ctx)
case user.FieldTotalRecharged:
return m.OldTotalRecharged(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -29724,6 +30054,41 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetTotpEnabledAt(v)
return nil
case user.FieldBalanceNotifyEnabled:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetBalanceNotifyEnabled(v)
return nil
case user.FieldBalanceNotifyThresholdType:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetBalanceNotifyThresholdType(v)
return nil
case user.FieldBalanceNotifyThreshold:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetBalanceNotifyThreshold(v)
return nil
case user.FieldBalanceNotifyExtraEmails:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetBalanceNotifyExtraEmails(v)
return nil
case user.FieldTotalRecharged:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetTotalRecharged(v)
return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -29738,6 +30103,12 @@ func (m *UserMutation) AddedFields() []string {
if m.addconcurrency != nil {
fields = append(fields, user.FieldConcurrency)
}
if m.addbalance_notify_threshold != nil {
fields = append(fields, user.FieldBalanceNotifyThreshold)
}
if m.addtotal_recharged != nil {
fields = append(fields, user.FieldTotalRecharged)
}
return fields
}
@@ -29750,6 +30121,10 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedBalance()
case user.FieldConcurrency:
return m.AddedConcurrency()
case user.FieldBalanceNotifyThreshold:
return m.AddedBalanceNotifyThreshold()
case user.FieldTotalRecharged:
return m.AddedTotalRecharged()
}
return nil, false
}
@@ -29773,6 +30148,20 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
}
m.AddConcurrency(v)
return nil
case user.FieldBalanceNotifyThreshold:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddBalanceNotifyThreshold(v)
return nil
case user.FieldTotalRecharged:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddTotalRecharged(v)
return nil
}
return fmt.Errorf("unknown User numeric field %s", name)
}
@@ -29790,6 +30179,9 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldTotpEnabledAt) {
fields = append(fields, user.FieldTotpEnabledAt)
}
if m.FieldCleared(user.FieldBalanceNotifyThreshold) {
fields = append(fields, user.FieldBalanceNotifyThreshold)
}
return fields
}
@@ -29813,6 +30205,9 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldTotpEnabledAt:
m.ClearTotpEnabledAt()
return nil
case user.FieldBalanceNotifyThreshold:
m.ClearBalanceNotifyThreshold()
return nil
}
return fmt.Errorf("unknown User nullable field %s", name)
}
@@ -29863,6 +30258,21 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt()
return nil
case user.FieldBalanceNotifyEnabled:
m.ResetBalanceNotifyEnabled()
return nil
case user.FieldBalanceNotifyThresholdType:
m.ResetBalanceNotifyThresholdType()
return nil
case user.FieldBalanceNotifyThreshold:
m.ResetBalanceNotifyThreshold()
return nil
case user.FieldBalanceNotifyExtraEmails:
m.ResetBalanceNotifyExtraEmails()
return nil
case user.FieldTotalRecharged:
m.ResetTotalRecharged()
return nil
}
return fmt.Errorf("unknown User field %s", name)
}

View File

@@ -35,6 +35,8 @@ type PaymentProviderInstance struct {
Limits string `json:"limits,omitempty"`
// RefundEnabled holds the value of the "refund_enabled" field.
RefundEnabled bool `json:"refund_enabled,omitempty"`
// AllowUserRefund holds the value of the "allow_user_refund" field.
AllowUserRefund bool `json:"allow_user_refund,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
@@ -47,7 +49,7 @@ func (*PaymentProviderInstance) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled:
case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled, paymentproviderinstance.FieldAllowUserRefund:
values[i] = new(sql.NullBool)
case paymentproviderinstance.FieldID, paymentproviderinstance.FieldSortOrder:
values[i] = new(sql.NullInt64)
@@ -130,6 +132,12 @@ func (_m *PaymentProviderInstance) assignValues(columns []string, values []any)
} else if value.Valid {
_m.RefundEnabled = value.Bool
}
case paymentproviderinstance.FieldAllowUserRefund:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field allow_user_refund", values[i])
} else if value.Valid {
_m.AllowUserRefund = value.Bool
}
case paymentproviderinstance.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
@@ -205,6 +213,9 @@ func (_m *PaymentProviderInstance) String() string {
builder.WriteString("refund_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.RefundEnabled))
builder.WriteString(", ")
builder.WriteString("allow_user_refund=")
builder.WriteString(fmt.Sprintf("%v", _m.AllowUserRefund))
builder.WriteString(", ")
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")

View File

@@ -31,6 +31,8 @@ const (
FieldLimits = "limits"
// FieldRefundEnabled holds the string denoting the refund_enabled field in the database.
FieldRefundEnabled = "refund_enabled"
// FieldAllowUserRefund holds the string denoting the allow_user_refund field in the database.
FieldAllowUserRefund = "allow_user_refund"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
@@ -51,6 +53,7 @@ var Columns = []string{
FieldSortOrder,
FieldLimits,
FieldRefundEnabled,
FieldAllowUserRefund,
FieldCreatedAt,
FieldUpdatedAt,
}
@@ -88,6 +91,8 @@ var (
DefaultLimits string
// DefaultRefundEnabled holds the default value on creation for the "refund_enabled" field.
DefaultRefundEnabled bool
// DefaultAllowUserRefund holds the default value on creation for the "allow_user_refund" field.
DefaultAllowUserRefund bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
@@ -149,6 +154,11 @@ func ByRefundEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRefundEnabled, opts...).ToFunc()
}
// ByAllowUserRefund orders the results by the allow_user_refund field.
func ByAllowUserRefund(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAllowUserRefund, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()

View File

@@ -99,6 +99,11 @@ func RefundEnabled(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v))
}
// AllowUserRefund applies equality check predicate on the "allow_user_refund" field. It's identical to AllowUserRefundEQ.
func AllowUserRefund(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))
@@ -559,6 +564,16 @@ func RefundEnabledNEQ(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldRefundEnabled, v))
}
// AllowUserRefundEQ applies the EQ predicate on the "allow_user_refund" field.
func AllowUserRefundEQ(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v))
}
// AllowUserRefundNEQ applies the NEQ predicate on the "allow_user_refund" field.
func AllowUserRefundNEQ(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldAllowUserRefund, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))

View File

@@ -132,6 +132,20 @@ func (_c *PaymentProviderInstanceCreate) SetNillableRefundEnabled(v *bool) *Paym
return _c
}
// SetAllowUserRefund sets the "allow_user_refund" field.
func (_c *PaymentProviderInstanceCreate) SetAllowUserRefund(v bool) *PaymentProviderInstanceCreate {
_c.mutation.SetAllowUserRefund(v)
return _c
}
// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
func (_c *PaymentProviderInstanceCreate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceCreate {
if v != nil {
_c.SetAllowUserRefund(*v)
}
return _c
}
// SetCreatedAt sets the "created_at" field.
func (_c *PaymentProviderInstanceCreate) SetCreatedAt(v time.Time) *PaymentProviderInstanceCreate {
_c.mutation.SetCreatedAt(v)
@@ -223,6 +237,10 @@ func (_c *PaymentProviderInstanceCreate) defaults() {
v := paymentproviderinstance.DefaultRefundEnabled
_c.mutation.SetRefundEnabled(v)
}
if _, ok := _c.mutation.AllowUserRefund(); !ok {
v := paymentproviderinstance.DefaultAllowUserRefund
_c.mutation.SetAllowUserRefund(v)
}
if _, ok := _c.mutation.CreatedAt(); !ok {
v := paymentproviderinstance.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
@@ -282,6 +300,9 @@ func (_c *PaymentProviderInstanceCreate) check() error {
if _, ok := _c.mutation.RefundEnabled(); !ok {
return &ValidationError{Name: "refund_enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.refund_enabled"`)}
}
if _, ok := _c.mutation.AllowUserRefund(); !ok {
return &ValidationError{Name: "allow_user_refund", err: errors.New(`ent: missing required field "PaymentProviderInstance.allow_user_refund"`)}
}
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.created_at"`)}
}
@@ -351,6 +372,10 @@ func (_c *PaymentProviderInstanceCreate) createSpec() (*PaymentProviderInstance,
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
_node.RefundEnabled = value
}
if value, ok := _c.mutation.AllowUserRefund(); ok {
_spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
_node.AllowUserRefund = value
}
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(paymentproviderinstance.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
@@ -525,6 +550,18 @@ func (u *PaymentProviderInstanceUpsert) UpdateRefundEnabled() *PaymentProviderIn
return u
}
// SetAllowUserRefund sets the "allow_user_refund" field.
func (u *PaymentProviderInstanceUpsert) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsert {
u.Set(paymentproviderinstance.FieldAllowUserRefund, v)
return u
}
// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
func (u *PaymentProviderInstanceUpsert) UpdateAllowUserRefund() *PaymentProviderInstanceUpsert {
u.SetExcluded(paymentproviderinstance.FieldAllowUserRefund)
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *PaymentProviderInstanceUpsert) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsert {
u.Set(paymentproviderinstance.FieldUpdatedAt, v)
@@ -715,6 +752,20 @@ func (u *PaymentProviderInstanceUpsertOne) UpdateRefundEnabled() *PaymentProvide
})
}
// SetAllowUserRefund sets the "allow_user_refund" field.
func (u *PaymentProviderInstanceUpsertOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertOne {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
s.SetAllowUserRefund(v)
})
}
// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
func (u *PaymentProviderInstanceUpsertOne) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertOne {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
s.UpdateAllowUserRefund()
})
}
// SetUpdatedAt sets the "updated_at" field.
func (u *PaymentProviderInstanceUpsertOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertOne {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
@@ -1073,6 +1124,20 @@ func (u *PaymentProviderInstanceUpsertBulk) UpdateRefundEnabled() *PaymentProvid
})
}
// SetAllowUserRefund sets the "allow_user_refund" field.
func (u *PaymentProviderInstanceUpsertBulk) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertBulk {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
s.SetAllowUserRefund(v)
})
}
// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
func (u *PaymentProviderInstanceUpsertBulk) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertBulk {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
s.UpdateAllowUserRefund()
})
}
// SetUpdatedAt sets the "updated_at" field.
func (u *PaymentProviderInstanceUpsertBulk) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertBulk {
return u.Update(func(s *PaymentProviderInstanceUpsert) {

View File

@@ -161,6 +161,20 @@ func (_u *PaymentProviderInstanceUpdate) SetNillableRefundEnabled(v *bool) *Paym
return _u
}
// SetAllowUserRefund sets the "allow_user_refund" field.
func (_u *PaymentProviderInstanceUpdate) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdate {
_u.mutation.SetAllowUserRefund(v)
return _u
}
// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
func (_u *PaymentProviderInstanceUpdate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdate {
if v != nil {
_u.SetAllowUserRefund(*v)
}
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *PaymentProviderInstanceUpdate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdate {
_u.mutation.SetUpdatedAt(v)
@@ -275,6 +289,9 @@ func (_u *PaymentProviderInstanceUpdate) sqlSave(ctx context.Context) (_node int
if value, ok := _u.mutation.RefundEnabled(); ok {
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.AllowUserRefund(); ok {
_spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
}
@@ -431,6 +448,20 @@ func (_u *PaymentProviderInstanceUpdateOne) SetNillableRefundEnabled(v *bool) *P
return _u
}
// SetAllowUserRefund sets the "allow_user_refund" field.
func (_u *PaymentProviderInstanceUpdateOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdateOne {
_u.mutation.SetAllowUserRefund(v)
return _u
}
// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
func (_u *PaymentProviderInstanceUpdateOne) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdateOne {
if v != nil {
_u.SetAllowUserRefund(*v)
}
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *PaymentProviderInstanceUpdateOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdateOne {
_u.mutation.SetUpdatedAt(v)
@@ -575,6 +606,9 @@ func (_u *PaymentProviderInstanceUpdateOne) sqlSave(ctx context.Context) (_node
if value, ok := _u.mutation.RefundEnabled(); ok {
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.AllowUserRefund(); ok {
_spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
}

View File

@@ -668,12 +668,16 @@ func init() {
paymentproviderinstanceDescRefundEnabled := paymentproviderinstanceFields[8].Descriptor()
// paymentproviderinstance.DefaultRefundEnabled holds the default value on creation for the refund_enabled field.
paymentproviderinstance.DefaultRefundEnabled = paymentproviderinstanceDescRefundEnabled.Default.(bool)
// paymentproviderinstanceDescAllowUserRefund is the schema descriptor for allow_user_refund field.
paymentproviderinstanceDescAllowUserRefund := paymentproviderinstanceFields[9].Descriptor()
// paymentproviderinstance.DefaultAllowUserRefund holds the default value on creation for the allow_user_refund field.
paymentproviderinstance.DefaultAllowUserRefund = paymentproviderinstanceDescAllowUserRefund.Default.(bool)
// paymentproviderinstanceDescCreatedAt is the schema descriptor for created_at field.
paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[9].Descriptor()
paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[10].Descriptor()
// paymentproviderinstance.DefaultCreatedAt holds the default value on creation for the created_at field.
paymentproviderinstance.DefaultCreatedAt = paymentproviderinstanceDescCreatedAt.Default.(func() time.Time)
// paymentproviderinstanceDescUpdatedAt is the schema descriptor for updated_at field.
paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[10].Descriptor()
paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[11].Descriptor()
// paymentproviderinstance.DefaultUpdatedAt holds the default value on creation for the updated_at field.
paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time)
// paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
@@ -1293,6 +1297,22 @@ func init() {
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
// userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field.
userDescBalanceNotifyEnabled := userFields[11].Descriptor()
// user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field.
user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool)
// userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field.
userDescBalanceNotifyThresholdType := userFields[12].Descriptor()
// user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field.
user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string)
// userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field.
userDescBalanceNotifyExtraEmails := userFields[14].Descriptor()
// user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field.
user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string)
// userDescTotalRecharged is the schema descriptor for total_recharged field.
userDescTotalRecharged := userFields[15].Descriptor()
// user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.

View File

@@ -53,6 +53,8 @@ func (PaymentProviderInstance) Fields() []ent.Field {
Default(""),
field.Bool("refund_enabled").
Default(false),
field.Bool("allow_user_refund").
Default(false),
field.Time("created_at").
Immutable().
Default(time.Now).

View File

@@ -72,6 +72,22 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
// 余额不足通知
field.Bool("balance_notify_enabled").
Default(true),
field.String("balance_notify_threshold_type").
Default("fixed"), // "fixed" | "percentage"
field.Float("balance_notify_threshold").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Optional().
Nillable(),
field.String("balance_notify_extra_emails").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Default("[]"),
field.Float("total_recharged").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0),
}
}

View File

@@ -45,6 +45,16 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
// BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field.
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type,omitempty"`
// BalanceNotifyThreshold holds the value of the "balance_notify_threshold" field.
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
// BalanceNotifyExtraEmails holds the value of the "balance_notify_extra_emails" field.
BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"`
// TotalRecharged holds the value of the "total_recharged" field.
TotalRecharged float64 `json:"total_recharged,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -184,13 +194,13 @@ func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case user.FieldTotpEnabled:
case user.FieldTotpEnabled, user.FieldBalanceNotifyEnabled:
values[i] = new(sql.NullBool)
case user.FieldBalance:
case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged:
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
values[i] = new(sql.NullTime)
@@ -302,6 +312,37 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
case user.FieldBalanceNotifyEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i])
} else if value.Valid {
_m.BalanceNotifyEnabled = value.Bool
}
case user.FieldBalanceNotifyThresholdType:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_threshold_type", values[i])
} else if value.Valid {
_m.BalanceNotifyThresholdType = value.String
}
case user.FieldBalanceNotifyThreshold:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_threshold", values[i])
} else if value.Valid {
_m.BalanceNotifyThreshold = new(float64)
*_m.BalanceNotifyThreshold = value.Float64
}
case user.FieldBalanceNotifyExtraEmails:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_extra_emails", values[i])
} else if value.Valid {
_m.BalanceNotifyExtraEmails = value.String
}
case user.FieldTotalRecharged:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field total_recharged", values[i])
} else if value.Valid {
_m.TotalRecharged = value.Float64
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -440,6 +481,23 @@ func (_m *User) String() string {
builder.WriteString("totp_enabled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("balance_notify_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
builder.WriteString(", ")
builder.WriteString("balance_notify_threshold_type=")
builder.WriteString(_m.BalanceNotifyThresholdType)
builder.WriteString(", ")
if v := _m.BalanceNotifyThreshold; v != nil {
builder.WriteString("balance_notify_threshold=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("balance_notify_extra_emails=")
builder.WriteString(_m.BalanceNotifyExtraEmails)
builder.WriteString(", ")
builder.WriteString("total_recharged=")
builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged))
builder.WriteByte(')')
return builder.String()
}

View File

@@ -43,6 +43,16 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
// FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
FieldBalanceNotifyEnabled = "balance_notify_enabled"
// FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database.
FieldBalanceNotifyThresholdType = "balance_notify_threshold_type"
// FieldBalanceNotifyThreshold holds the string denoting the balance_notify_threshold field in the database.
FieldBalanceNotifyThreshold = "balance_notify_threshold"
// FieldBalanceNotifyExtraEmails holds the string denoting the balance_notify_extra_emails field in the database.
FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails"
// FieldTotalRecharged holds the string denoting the total_recharged field in the database.
FieldTotalRecharged = "total_recharged"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -161,6 +171,11 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
FieldBalanceNotifyEnabled,
FieldBalanceNotifyThresholdType,
FieldBalanceNotifyThreshold,
FieldBalanceNotifyExtraEmails,
FieldTotalRecharged,
}
var (
@@ -217,6 +232,14 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
// DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
DefaultBalanceNotifyEnabled bool
// DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field.
DefaultBalanceNotifyThresholdType string
// DefaultBalanceNotifyExtraEmails holds the default value on creation for the "balance_notify_extra_emails" field.
DefaultBalanceNotifyExtraEmails string
// DefaultTotalRecharged holds the default value on creation for the "total_recharged" field.
DefaultTotalRecharged float64
)
// OrderOption defines the ordering options for the User queries.
@@ -297,6 +320,31 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field.
func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
}
// ByBalanceNotifyThresholdType orders the results by the balance_notify_threshold_type field.
func ByBalanceNotifyThresholdType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyThresholdType, opts...).ToFunc()
}
// ByBalanceNotifyThreshold orders the results by the balance_notify_threshold field.
func ByBalanceNotifyThreshold(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyThreshold, opts...).ToFunc()
}
// ByBalanceNotifyExtraEmails orders the results by the balance_notify_extra_emails field.
func ByBalanceNotifyExtraEmails(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyExtraEmails, opts...).ToFunc()
}
// ByTotalRecharged orders the results by the total_recharged field.
func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {

View File

@@ -125,6 +125,31 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ.
func BalanceNotifyEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
}
// BalanceNotifyThresholdType applies equality check predicate on the "balance_notify_threshold_type" field. It's identical to BalanceNotifyThresholdTypeEQ.
func BalanceNotifyThresholdType(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThreshold applies equality check predicate on the "balance_notify_threshold" field. It's identical to BalanceNotifyThresholdEQ.
func BalanceNotifyThreshold(v float64) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyExtraEmails applies equality check predicate on the "balance_notify_extra_emails" field. It's identical to BalanceNotifyExtraEmailsEQ.
func BalanceNotifyExtraEmails(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v))
}
// TotalRecharged applies equality check predicate on the "total_recharged" field. It's identical to TotalRechargedEQ.
func TotalRecharged(v float64) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -860,6 +885,236 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field.
func BalanceNotifyEnabledEQ(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
}
// BalanceNotifyEnabledNEQ applies the NEQ predicate on the "balance_notify_enabled" field.
func BalanceNotifyEnabledNEQ(v bool) predicate.User {
return predicate.User(sql.FieldNEQ(FieldBalanceNotifyEnabled, v))
}
// BalanceNotifyThresholdTypeEQ applies the EQ predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeNEQ applies the NEQ predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeIn applies the In predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldBalanceNotifyThresholdType, vs...))
}
// BalanceNotifyThresholdTypeNotIn applies the NotIn predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThresholdType, vs...))
}
// BalanceNotifyThresholdTypeGT applies the GT predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeGTE applies the GTE predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeLT applies the LT predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeLTE applies the LTE predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeContains applies the Contains predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeHasPrefix applies the HasPrefix predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeHasSuffix applies the HasSuffix predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeEqualFold applies the EqualFold predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeContainsFold applies the ContainsFold predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdEQ applies the EQ predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdEQ(v float64) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyThresholdNEQ applies the NEQ predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdNEQ(v float64) predicate.User {
return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyThresholdIn applies the In predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdIn(vs ...float64) predicate.User {
return predicate.User(sql.FieldIn(FieldBalanceNotifyThreshold, vs...))
}
// BalanceNotifyThresholdNotIn applies the NotIn predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdNotIn(vs ...float64) predicate.User {
return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThreshold, vs...))
}
// BalanceNotifyThresholdGT applies the GT predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdGT(v float64) predicate.User {
return predicate.User(sql.FieldGT(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyThresholdGTE applies the GTE predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdGTE(v float64) predicate.User {
return predicate.User(sql.FieldGTE(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyThresholdLT applies the LT predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdLT(v float64) predicate.User {
return predicate.User(sql.FieldLT(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyThresholdLTE applies the LTE predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdLTE(v float64) predicate.User {
return predicate.User(sql.FieldLTE(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyThresholdIsNil applies the IsNil predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldBalanceNotifyThreshold))
}
// BalanceNotifyThresholdNotNil applies the NotNil predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldBalanceNotifyThreshold))
}
// BalanceNotifyExtraEmailsEQ applies the EQ predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsNEQ applies the NEQ predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsIn applies the In predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldBalanceNotifyExtraEmails, vs...))
}
// BalanceNotifyExtraEmailsNotIn applies the NotIn predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldBalanceNotifyExtraEmails, vs...))
}
// BalanceNotifyExtraEmailsGT applies the GT predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsGTE applies the GTE predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsLT applies the LT predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsLTE applies the LTE predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsContains applies the Contains predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsHasPrefix applies the HasPrefix predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsHasSuffix applies the HasSuffix predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsEqualFold applies the EqualFold predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsContainsFold applies the ContainsFold predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyExtraEmails, v))
}
// TotalRechargedEQ applies the EQ predicate on the "total_recharged" field.
func TotalRechargedEQ(v float64) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
}
// TotalRechargedNEQ applies the NEQ predicate on the "total_recharged" field.
func TotalRechargedNEQ(v float64) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotalRecharged, v))
}
// TotalRechargedIn applies the In predicate on the "total_recharged" field.
func TotalRechargedIn(vs ...float64) predicate.User {
return predicate.User(sql.FieldIn(FieldTotalRecharged, vs...))
}
// TotalRechargedNotIn applies the NotIn predicate on the "total_recharged" field.
func TotalRechargedNotIn(vs ...float64) predicate.User {
return predicate.User(sql.FieldNotIn(FieldTotalRecharged, vs...))
}
// TotalRechargedGT applies the GT predicate on the "total_recharged" field.
func TotalRechargedGT(v float64) predicate.User {
return predicate.User(sql.FieldGT(FieldTotalRecharged, v))
}
// TotalRechargedGTE applies the GTE predicate on the "total_recharged" field.
func TotalRechargedGTE(v float64) predicate.User {
return predicate.User(sql.FieldGTE(FieldTotalRecharged, v))
}
// TotalRechargedLT applies the LT predicate on the "total_recharged" field.
func TotalRechargedLT(v float64) predicate.User {
return predicate.User(sql.FieldLT(FieldTotalRecharged, v))
}
// TotalRechargedLTE applies the LTE predicate on the "total_recharged" field.
func TotalRechargedLTE(v float64) predicate.User {
return predicate.User(sql.FieldLTE(FieldTotalRecharged, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {

View File

@@ -211,6 +211,76 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate {
_c.mutation.SetBalanceNotifyEnabled(v)
return _c
}
// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
func (_c *UserCreate) SetNillableBalanceNotifyEnabled(v *bool) *UserCreate {
if v != nil {
_c.SetBalanceNotifyEnabled(*v)
}
return _c
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (_c *UserCreate) SetBalanceNotifyThresholdType(v string) *UserCreate {
_c.mutation.SetBalanceNotifyThresholdType(v)
return _c
}
// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
func (_c *UserCreate) SetNillableBalanceNotifyThresholdType(v *string) *UserCreate {
if v != nil {
_c.SetBalanceNotifyThresholdType(*v)
}
return _c
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (_c *UserCreate) SetBalanceNotifyThreshold(v float64) *UserCreate {
_c.mutation.SetBalanceNotifyThreshold(v)
return _c
}
// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
func (_c *UserCreate) SetNillableBalanceNotifyThreshold(v *float64) *UserCreate {
if v != nil {
_c.SetBalanceNotifyThreshold(*v)
}
return _c
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (_c *UserCreate) SetBalanceNotifyExtraEmails(v string) *UserCreate {
_c.mutation.SetBalanceNotifyExtraEmails(v)
return _c
}
// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
func (_c *UserCreate) SetNillableBalanceNotifyExtraEmails(v *string) *UserCreate {
if v != nil {
_c.SetBalanceNotifyExtraEmails(*v)
}
return _c
}
// SetTotalRecharged sets the "total_recharged" field.
func (_c *UserCreate) SetTotalRecharged(v float64) *UserCreate {
_c.mutation.SetTotalRecharged(v)
return _c
}
// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate {
if v != nil {
_c.SetTotalRecharged(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -440,6 +510,22 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
v := user.DefaultBalanceNotifyEnabled
_c.mutation.SetBalanceNotifyEnabled(v)
}
if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok {
v := user.DefaultBalanceNotifyThresholdType
_c.mutation.SetBalanceNotifyThresholdType(v)
}
if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
v := user.DefaultBalanceNotifyExtraEmails
_c.mutation.SetBalanceNotifyExtraEmails(v)
}
if _, ok := _c.mutation.TotalRecharged(); !ok {
v := user.DefaultTotalRecharged
_c.mutation.SetTotalRecharged(v)
}
return nil
}
@@ -503,6 +589,18 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
}
if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok {
return &ValidationError{Name: "balance_notify_threshold_type", err: errors.New(`ent: missing required field "User.balance_notify_threshold_type"`)}
}
if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
return &ValidationError{Name: "balance_notify_extra_emails", err: errors.New(`ent: missing required field "User.balance_notify_extra_emails"`)}
}
if _, ok := _c.mutation.TotalRecharged(); !ok {
return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)}
}
return nil
}
@@ -586,6 +684,26 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
if value, ok := _c.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
_node.BalanceNotifyEnabled = value
}
if value, ok := _c.mutation.BalanceNotifyThresholdType(); ok {
_spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
_node.BalanceNotifyThresholdType = value
}
if value, ok := _c.mutation.BalanceNotifyThreshold(); ok {
_spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
_node.BalanceNotifyThreshold = &value
}
if value, ok := _c.mutation.BalanceNotifyExtraEmails(); ok {
_spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
_node.BalanceNotifyExtraEmails = value
}
if value, ok := _c.mutation.TotalRecharged(); ok {
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
_node.TotalRecharged = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -988,6 +1106,84 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert {
u.Set(user.FieldBalanceNotifyEnabled, v)
return u
}
// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
func (u *UserUpsert) UpdateBalanceNotifyEnabled() *UserUpsert {
u.SetExcluded(user.FieldBalanceNotifyEnabled)
return u
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (u *UserUpsert) SetBalanceNotifyThresholdType(v string) *UserUpsert {
u.Set(user.FieldBalanceNotifyThresholdType, v)
return u
}
// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
func (u *UserUpsert) UpdateBalanceNotifyThresholdType() *UserUpsert {
u.SetExcluded(user.FieldBalanceNotifyThresholdType)
return u
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (u *UserUpsert) SetBalanceNotifyThreshold(v float64) *UserUpsert {
u.Set(user.FieldBalanceNotifyThreshold, v)
return u
}
// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
func (u *UserUpsert) UpdateBalanceNotifyThreshold() *UserUpsert {
u.SetExcluded(user.FieldBalanceNotifyThreshold)
return u
}
// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
func (u *UserUpsert) AddBalanceNotifyThreshold(v float64) *UserUpsert {
u.Add(user.FieldBalanceNotifyThreshold, v)
return u
}
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
func (u *UserUpsert) ClearBalanceNotifyThreshold() *UserUpsert {
u.SetNull(user.FieldBalanceNotifyThreshold)
return u
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (u *UserUpsert) SetBalanceNotifyExtraEmails(v string) *UserUpsert {
u.Set(user.FieldBalanceNotifyExtraEmails, v)
return u
}
// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
func (u *UserUpsert) UpdateBalanceNotifyExtraEmails() *UserUpsert {
u.SetExcluded(user.FieldBalanceNotifyExtraEmails)
return u
}
// SetTotalRecharged sets the "total_recharged" field.
func (u *UserUpsert) SetTotalRecharged(v float64) *UserUpsert {
u.Set(user.FieldTotalRecharged, v)
return u
}
// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotalRecharged() *UserUpsert {
u.SetExcluded(user.FieldTotalRecharged)
return u
}
// AddTotalRecharged adds v to the "total_recharged" field.
func (u *UserUpsert) AddTotalRecharged(v float64) *UserUpsert {
u.Add(user.FieldTotalRecharged, v)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1250,6 +1446,97 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyEnabled(v)
})
}
// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateBalanceNotifyEnabled() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyEnabled()
})
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (u *UserUpsertOne) SetBalanceNotifyThresholdType(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyThresholdType(v)
})
}
// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateBalanceNotifyThresholdType() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyThresholdType()
})
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (u *UserUpsertOne) SetBalanceNotifyThreshold(v float64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyThreshold(v)
})
}
// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
func (u *UserUpsertOne) AddBalanceNotifyThreshold(v float64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.AddBalanceNotifyThreshold(v)
})
}
// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateBalanceNotifyThreshold() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyThreshold()
})
}
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
func (u *UserUpsertOne) ClearBalanceNotifyThreshold() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearBalanceNotifyThreshold()
})
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (u *UserUpsertOne) SetBalanceNotifyExtraEmails(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyExtraEmails(v)
})
}
// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateBalanceNotifyExtraEmails() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyExtraEmails()
})
}
// SetTotalRecharged sets the "total_recharged" field.
func (u *UserUpsertOne) SetTotalRecharged(v float64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotalRecharged(v)
})
}
// AddTotalRecharged adds v to the "total_recharged" field.
func (u *UserUpsertOne) AddTotalRecharged(v float64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.AddTotalRecharged(v)
})
}
// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotalRecharged()
})
}
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1678,6 +1965,97 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyEnabled(v)
})
}
// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateBalanceNotifyEnabled() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyEnabled()
})
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (u *UserUpsertBulk) SetBalanceNotifyThresholdType(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyThresholdType(v)
})
}
// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateBalanceNotifyThresholdType() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyThresholdType()
})
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (u *UserUpsertBulk) SetBalanceNotifyThreshold(v float64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyThreshold(v)
})
}
// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
func (u *UserUpsertBulk) AddBalanceNotifyThreshold(v float64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.AddBalanceNotifyThreshold(v)
})
}
// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateBalanceNotifyThreshold() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyThreshold()
})
}
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
func (u *UserUpsertBulk) ClearBalanceNotifyThreshold() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearBalanceNotifyThreshold()
})
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (u *UserUpsertBulk) SetBalanceNotifyExtraEmails(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyExtraEmails(v)
})
}
// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateBalanceNotifyExtraEmails() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyExtraEmails()
})
}
// SetTotalRecharged sets the "total_recharged" field.
func (u *UserUpsertBulk) SetTotalRecharged(v float64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotalRecharged(v)
})
}
// AddTotalRecharged adds v to the "total_recharged" field.
func (u *UserUpsertBulk) AddTotalRecharged(v float64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.AddTotalRecharged(v)
})
}
// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotalRecharged()
})
}
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -243,6 +243,96 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate {
_u.mutation.SetBalanceNotifyEnabled(v)
return _u
}
// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
func (_u *UserUpdate) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdate {
if v != nil {
_u.SetBalanceNotifyEnabled(*v)
}
return _u
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (_u *UserUpdate) SetBalanceNotifyThresholdType(v string) *UserUpdate {
_u.mutation.SetBalanceNotifyThresholdType(v)
return _u
}
// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
func (_u *UserUpdate) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdate {
if v != nil {
_u.SetBalanceNotifyThresholdType(*v)
}
return _u
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (_u *UserUpdate) SetBalanceNotifyThreshold(v float64) *UserUpdate {
_u.mutation.ResetBalanceNotifyThreshold()
_u.mutation.SetBalanceNotifyThreshold(v)
return _u
}
// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
func (_u *UserUpdate) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdate {
if v != nil {
_u.SetBalanceNotifyThreshold(*v)
}
return _u
}
// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field.
func (_u *UserUpdate) AddBalanceNotifyThreshold(v float64) *UserUpdate {
_u.mutation.AddBalanceNotifyThreshold(v)
return _u
}
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
func (_u *UserUpdate) ClearBalanceNotifyThreshold() *UserUpdate {
_u.mutation.ClearBalanceNotifyThreshold()
return _u
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (_u *UserUpdate) SetBalanceNotifyExtraEmails(v string) *UserUpdate {
_u.mutation.SetBalanceNotifyExtraEmails(v)
return _u
}
// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
func (_u *UserUpdate) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdate {
if v != nil {
_u.SetBalanceNotifyExtraEmails(*v)
}
return _u
}
// SetTotalRecharged sets the "total_recharged" field.
func (_u *UserUpdate) SetTotalRecharged(v float64) *UserUpdate {
_u.mutation.ResetTotalRecharged()
_u.mutation.SetTotalRecharged(v)
return _u
}
// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotalRecharged(v *float64) *UserUpdate {
if v != nil {
_u.SetTotalRecharged(*v)
}
return _u
}
// AddTotalRecharged adds value to the "total_recharged" field.
func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate {
_u.mutation.AddTotalRecharged(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -746,6 +836,30 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok {
_spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
}
if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
_spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok {
_spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
}
if _u.mutation.BalanceNotifyThresholdCleared() {
_spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64)
}
if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
_spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
}
if value, ok := _u.mutation.TotalRecharged(); ok {
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1434,6 +1548,96 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne {
_u.mutation.SetBalanceNotifyEnabled(v)
return _u
}
// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdateOne {
if v != nil {
_u.SetBalanceNotifyEnabled(*v)
}
return _u
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (_u *UserUpdateOne) SetBalanceNotifyThresholdType(v string) *UserUpdateOne {
_u.mutation.SetBalanceNotifyThresholdType(v)
return _u
}
// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdateOne {
if v != nil {
_u.SetBalanceNotifyThresholdType(*v)
}
return _u
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (_u *UserUpdateOne) SetBalanceNotifyThreshold(v float64) *UserUpdateOne {
_u.mutation.ResetBalanceNotifyThreshold()
_u.mutation.SetBalanceNotifyThreshold(v)
return _u
}
// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdateOne {
if v != nil {
_u.SetBalanceNotifyThreshold(*v)
}
return _u
}
// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field.
func (_u *UserUpdateOne) AddBalanceNotifyThreshold(v float64) *UserUpdateOne {
_u.mutation.AddBalanceNotifyThreshold(v)
return _u
}
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
func (_u *UserUpdateOne) ClearBalanceNotifyThreshold() *UserUpdateOne {
_u.mutation.ClearBalanceNotifyThreshold()
return _u
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (_u *UserUpdateOne) SetBalanceNotifyExtraEmails(v string) *UserUpdateOne {
_u.mutation.SetBalanceNotifyExtraEmails(v)
return _u
}
// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdateOne {
if v != nil {
_u.SetBalanceNotifyExtraEmails(*v)
}
return _u
}
// SetTotalRecharged sets the "total_recharged" field.
func (_u *UserUpdateOne) SetTotalRecharged(v float64) *UserUpdateOne {
_u.mutation.ResetTotalRecharged()
_u.mutation.SetTotalRecharged(v)
return _u
}
// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotalRecharged(v *float64) *UserUpdateOne {
if v != nil {
_u.SetTotalRecharged(*v)
}
return _u
}
// AddTotalRecharged adds value to the "total_recharged" field.
func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne {
_u.mutation.AddTotalRecharged(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1967,6 +2171,30 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok {
_spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
}
if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
_spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok {
_spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
}
if _u.mutation.BalanceNotifyThresholdCleared() {
_spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64)
}
if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
_spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
}
if value, ok := _u.mutation.TotalRecharged(); ok {
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,

View File

@@ -183,6 +183,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -218,6 +220,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -251,6 +255,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -280,6 +286,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@@ -312,6 +320,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=

View File

@@ -28,7 +28,7 @@ const (
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com https://*.stripe.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// UMQ用户消息队列模式常量
const (

View File

@@ -233,12 +233,13 @@ func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
configPath := filepath.Join(tempDir, "config.yaml")
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+templatePath+"\"\n"), 0o644))
yamlSafePath := filepath.ToSlash(templatePath)
require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+yamlSafePath+"\"\n"), 0o644))
t.Setenv("DATA_DIR", tempDir)
cfg, err := Load()
require.NoError(t, err)
require.Equal(t, templatePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile)
require.Equal(t, yamlSafePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile)
require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate)
}

View File

@@ -1412,6 +1412,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
})
return
}

View File

@@ -1,6 +1,7 @@
package admin
import (
"fmt"
"strconv"
"strings"
@@ -26,24 +27,32 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s
// --- Request / Response types ---
type createChannelRequest struct {
Name string `json:"name" binding:"required,max=100"`
Description string `json:"description"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels bool `json:"restrict_models"`
Name string `json:"name" binding:"required,max=100"`
Description string `json:"description"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels bool `json:"restrict_models"`
Features string `json:"features"`
FeaturesConfig map[string]any `json:"features_config"`
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
}
type updateChannelRequest struct {
Name string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels *bool `json:"restrict_models"`
Name string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels *bool `json:"restrict_models"`
Features *string `json:"features"`
FeaturesConfig map[string]any `json:"features_config"`
ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"`
AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
}
type channelModelPricingRequest struct {
@@ -71,18 +80,29 @@ type pricingIntervalRequest struct {
SortOrder int `json:"sort_order"`
}
type accountStatsPricingRuleRequest struct {
Name string `json:"name"`
GroupIDs []int64 `json:"group_ids"`
AccountIDs []int64 `json:"account_ids"`
Pricing []channelModelPricingRequest `json:"pricing"`
}
type channelResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Status string `json:"status"`
BillingModelSource string `json:"billing_model_source"`
RestrictModels bool `json:"restrict_models"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Status string `json:"status"`
BillingModelSource string `json:"billing_model_source"`
RestrictModels bool `json:"restrict_models"`
Features string `json:"features"`
FeaturesConfig map[string]any `json:"features_config"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
AccountStatsPricingRules []accountStatsPricingRuleResponse `json:"account_stats_pricing_rules"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
type channelModelPricingResponse struct {
@@ -112,6 +132,14 @@ type pricingIntervalResponse struct {
SortOrder int `json:"sort_order"`
}
type accountStatsPricingRuleResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
GroupIDs []int64 `json:"group_ids"`
AccountIDs []int64 `json:"account_ids"`
Pricing []channelModelPricingResponse `json:"pricing"`
}
func channelToResponse(ch *service.Channel) *channelResponse {
if ch == nil {
return nil
@@ -122,6 +150,8 @@ func channelToResponse(ch *service.Channel) *channelResponse {
Description: ch.Description,
Status: ch.Status,
RestrictModels: ch.RestrictModels,
Features: ch.Features,
FeaturesConfig: ch.FeaturesConfig,
GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
@@ -142,6 +172,29 @@ func channelToResponse(ch *service.Channel) *channelResponse {
for _, p := range ch.ModelPricing {
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
}
resp.ApplyPricingToAccountStats = ch.ApplyPricingToAccountStats
resp.AccountStatsPricingRules = make([]accountStatsPricingRuleResponse, 0, len(ch.AccountStatsPricingRules))
for _, rule := range ch.AccountStatsPricingRules {
ruleResp := accountStatsPricingRuleResponse{
ID: rule.ID,
Name: rule.Name,
GroupIDs: rule.GroupIDs,
AccountIDs: rule.AccountIDs,
Pricing: make([]channelModelPricingResponse, 0, len(rule.Pricing)),
}
if ruleResp.GroupIDs == nil {
ruleResp.GroupIDs = []int64{}
}
if ruleResp.AccountIDs == nil {
ruleResp.AccountIDs = []int64{}
}
for i := range rule.Pricing {
ruleResp.Pricing = append(ruleResp.Pricing, pricingToResponse(&rule.Pricing[i]))
}
resp.AccountStatsPricingRules = append(resp.AccountStatsPricingRules, ruleResp)
}
return resp
}
@@ -200,9 +253,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
billingMode = service.BillingModeToken
}
platform := r.Platform
if platform == "" {
platform = service.PlatformAnthropic
}
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
for _, iv := range r.Intervals {
intervals = append(intervals, service.PricingInterval{
@@ -233,6 +283,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
return result
}
func accountStatsPricingRuleRequestToService(r accountStatsPricingRuleRequest) service.AccountStatsPricingRule {
return service.AccountStatsPricingRule{
Name: r.Name,
GroupIDs: r.GroupIDs,
AccountIDs: r.AccountIDs,
Pricing: pricingRequestToService(r.Pricing),
}
}
// --- Handlers ---
// List handles listing channels with pagination
@@ -291,15 +350,42 @@ func (h *ChannelHandler) Create(c *gin.Context) {
}
pricing := pricingRequestToService(req.ModelPricing)
// Main model_pricing requires a platform; default to anthropic for backward compatibility.
for i := range pricing {
if pricing[i].Platform == "" {
pricing[i].Platform = service.PlatformAnthropic
}
}
var statsRules []service.AccountStatsPricingRule
for i, r := range req.AccountStatsPricingRules {
if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return
}
if len(r.Pricing) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i
statsRules = append(statsRules, rule)
}
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
Name: req.Name,
Description: req.Description,
GroupIDs: req.GroupIDs,
ModelPricing: pricing,
ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
Name: req.Name,
Description: req.Description,
GroupIDs: req.GroupIDs,
ModelPricing: pricing,
ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
Features: req.Features,
FeaturesConfig: req.FeaturesConfig,
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
AccountStatsPricingRules: statsRules,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -325,18 +411,45 @@ func (h *ChannelHandler) Update(c *gin.Context) {
}
input := &service.UpdateChannelInput{
Name: req.Name,
Description: req.Description,
Status: req.Status,
GroupIDs: req.GroupIDs,
ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
Name: req.Name,
Description: req.Description,
Status: req.Status,
GroupIDs: req.GroupIDs,
ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
Features: req.Features,
FeaturesConfig: req.FeaturesConfig,
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
}
if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing)
for i := range pricing {
if pricing[i].Platform == "" {
pricing[i].Platform = service.PlatformAnthropic
}
}
input.ModelPricing = &pricing
}
if req.AccountStatsPricingRules != nil {
statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules))
for i, r := range *req.AccountStatsPricingRules {
if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return
}
if len(r.Pricing) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i
statsRules = append(statsRules, rule)
}
input.AccountStatsPricingRules = &statsRules
}
channel, err := h.channelService.Update(c.Request.Context(), id, input)
if err != nil {

View File

@@ -273,13 +273,13 @@ func TestPricingRequestToService_Defaults(t *testing.T) {
wantValue: string(service.BillingModeToken),
},
{
name: "empty platform defaults to anthropic",
name: "empty platform stays empty",
req: channelModelPricingRequest{
Models: []string{"m1"},
Platform: "",
},
wantField: "Platform",
wantValue: "anthropic",
wantValue: "",
},
}

View File

@@ -5,11 +5,10 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"log"
"log/slog"
"net/http"
"regexp"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -175,6 +174,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
PaymentEnabled: paymentCfg.Enabled,
PaymentMinAmount: paymentCfg.MinAmount,
PaymentMaxAmount: paymentCfg.MaxAmount,
@@ -183,6 +188,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders,
PaymentEnabledTypes: paymentCfg.EnabledTypes,
PaymentBalanceDisabled: paymentCfg.BalanceDisabled,
PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier,
PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate,
PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy,
PaymentProductNamePrefix: paymentCfg.ProductNamePrefix,
PaymentProductNameSuffix: paymentCfg.ProductNameSuffix,
@@ -304,20 +311,29 @@ type UpdateSettingsRequest struct {
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"`
// Balance low notification
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"`
AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"`
// Payment configuration (integrated into settings, full replace)
PaymentEnabled *bool `json:"payment_enabled"`
PaymentMinAmount *float64 `json:"payment_min_amount"`
PaymentMaxAmount *float64 `json:"payment_max_amount"`
PaymentDailyLimit *float64 `json:"payment_daily_limit"`
PaymentOrderTimeoutMin *int `json:"payment_order_timeout_minutes"`
PaymentMaxPendingOrders *int `json:"payment_max_pending_orders"`
PaymentEnabledTypes []string `json:"payment_enabled_types"`
PaymentBalanceDisabled *bool `json:"payment_balance_disabled"`
PaymentLoadBalanceStrat *string `json:"payment_load_balance_strategy"`
PaymentProductNamePrefix *string `json:"payment_product_name_prefix"`
PaymentProductNameSuffix *string `json:"payment_product_name_suffix"`
PaymentHelpImageURL *string `json:"payment_help_image_url"`
PaymentHelpText *string `json:"payment_help_text"`
PaymentEnabled *bool `json:"payment_enabled"`
PaymentMinAmount *float64 `json:"payment_min_amount"`
PaymentMaxAmount *float64 `json:"payment_max_amount"`
PaymentDailyLimit *float64 `json:"payment_daily_limit"`
PaymentOrderTimeoutMin *int `json:"payment_order_timeout_minutes"`
PaymentMaxPendingOrders *int `json:"payment_max_pending_orders"`
PaymentEnabledTypes []string `json:"payment_enabled_types"`
PaymentBalanceDisabled *bool `json:"payment_balance_disabled"`
PaymentBalanceRechargeMultiplier *float64 `json:"payment_balance_recharge_multiplier"`
PaymentRechargeFeeRate *float64 `json:"payment_recharge_fee_rate"`
PaymentLoadBalanceStrat *string `json:"payment_load_balance_strategy"`
PaymentProductNamePrefix *string `json:"payment_product_name_prefix"`
PaymentProductNameSuffix *string `json:"payment_product_name_suffix"`
PaymentHelpImageURL *string `json:"payment_help_image_url"`
PaymentHelpText *string `json:"payment_help_text"`
// Cancel rate limit
PaymentCancelRateLimitEnabled *bool `json:"payment_cancel_rate_limit_enabled"`
@@ -881,6 +897,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableCCHSigning
}(),
BalanceLowNotifyEnabled: func() bool {
if req.BalanceLowNotifyEnabled != nil {
return *req.BalanceLowNotifyEnabled
}
return previousSettings.BalanceLowNotifyEnabled
}(),
BalanceLowNotifyThreshold: func() float64 {
if req.BalanceLowNotifyThreshold != nil {
return *req.BalanceLowNotifyThreshold
}
return previousSettings.BalanceLowNotifyThreshold
}(),
BalanceLowNotifyRechargeURL: func() string {
if req.BalanceLowNotifyRechargeURL != nil {
return *req.BalanceLowNotifyRechargeURL
}
return previousSettings.BalanceLowNotifyRechargeURL
}(),
AccountQuotaNotifyEnabled: func() bool {
if req.AccountQuotaNotifyEnabled != nil {
return *req.AccountQuotaNotifyEnabled
}
return previousSettings.AccountQuotaNotifyEnabled
}(),
AccountQuotaNotifyEmails: func() []service.NotifyEmailEntry {
if req.AccountQuotaNotifyEmails != nil {
return dto.NotifyEmailEntriesToService(*req.AccountQuotaNotifyEmails)
}
return previousSettings.AccountQuotaNotifyEmails
}(),
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
@@ -892,24 +938,26 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
// Skip if no payment fields were provided (prevents accidental wipe).
if h.paymentConfigService != nil && hasPaymentFields(req) {
paymentReq := service.UpdatePaymentConfigRequest{
Enabled: req.PaymentEnabled,
MinAmount: req.PaymentMinAmount,
MaxAmount: req.PaymentMaxAmount,
DailyLimit: req.PaymentDailyLimit,
OrderTimeoutMin: req.PaymentOrderTimeoutMin,
MaxPendingOrders: req.PaymentMaxPendingOrders,
EnabledTypes: req.PaymentEnabledTypes,
BalanceDisabled: req.PaymentBalanceDisabled,
LoadBalanceStrategy: req.PaymentLoadBalanceStrat,
ProductNamePrefix: req.PaymentProductNamePrefix,
ProductNameSuffix: req.PaymentProductNameSuffix,
HelpImageURL: req.PaymentHelpImageURL,
HelpText: req.PaymentHelpText,
CancelRateLimitEnabled: req.PaymentCancelRateLimitEnabled,
CancelRateLimitMax: req.PaymentCancelRateLimitMax,
CancelRateLimitWindow: req.PaymentCancelRateLimitWindow,
CancelRateLimitUnit: req.PaymentCancelRateLimitUnit,
CancelRateLimitMode: req.PaymentCancelRateLimitMode,
Enabled: req.PaymentEnabled,
MinAmount: req.PaymentMinAmount,
MaxAmount: req.PaymentMaxAmount,
DailyLimit: req.PaymentDailyLimit,
OrderTimeoutMin: req.PaymentOrderTimeoutMin,
MaxPendingOrders: req.PaymentMaxPendingOrders,
EnabledTypes: req.PaymentEnabledTypes,
BalanceDisabled: req.PaymentBalanceDisabled,
BalanceRechargeMultiplier: req.PaymentBalanceRechargeMultiplier,
RechargeFeeRate: req.PaymentRechargeFeeRate,
LoadBalanceStrategy: req.PaymentLoadBalanceStrat,
ProductNamePrefix: req.PaymentProductNamePrefix,
ProductNameSuffix: req.PaymentProductNameSuffix,
HelpImageURL: req.PaymentHelpImageURL,
HelpText: req.PaymentHelpText,
CancelRateLimitEnabled: req.PaymentCancelRateLimitEnabled,
CancelRateLimitMax: req.PaymentCancelRateLimitMax,
CancelRateLimitWindow: req.PaymentCancelRateLimitWindow,
CancelRateLimitUnit: req.PaymentCancelRateLimitUnit,
CancelRateLimitMode: req.PaymentCancelRateLimitMode,
}
if err := h.paymentConfigService.UpdatePaymentConfig(c.Request.Context(), paymentReq); err != nil {
response.ErrorFrom(c, err)
@@ -1027,6 +1075,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning,
BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
PaymentEnabled: updatedPaymentCfg.Enabled,
PaymentMinAmount: updatedPaymentCfg.MinAmount,
PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
@@ -1035,6 +1088,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders,
PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes,
PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled,
PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier,
PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate,
PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy,
PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix,
PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix,
@@ -1054,6 +1109,7 @@ func hasPaymentFields(req UpdateSettingsRequest) bool {
req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil ||
req.PaymentOrderTimeoutMin != nil || req.PaymentMaxPendingOrders != nil ||
req.PaymentEnabledTypes != nil || req.PaymentBalanceDisabled != nil ||
req.PaymentBalanceRechargeMultiplier != nil || req.PaymentRechargeFeeRate != nil ||
req.PaymentLoadBalanceStrat != nil || req.PaymentProductNamePrefix != nil ||
req.PaymentProductNameSuffix != nil || req.PaymentHelpImageURL != nil ||
req.PaymentHelpText != nil || req.PaymentCancelRateLimitEnabled != nil ||
@@ -1073,11 +1129,11 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys
subject, _ := middleware.GetAuthSubjectFromContext(c)
role, _ := middleware.GetUserRoleFromContext(c)
log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v",
time.Now().UTC().Format(time.RFC3339),
subject.UserID,
role,
changed,
slog.Info("settings updated",
"audit", true,
"user_id", subject.UserID,
"role", role,
"changed", changed,
)
}
@@ -1092,6 +1148,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
changed = append(changed, "registration_email_suffix_whitelist")
}
if before.PromoCodeEnabled != after.PromoCodeEnabled {
changed = append(changed, "promo_code_enabled")
}
if before.InvitationCodeEnabled != after.InvitationCodeEnabled {
changed = append(changed, "invitation_code_enabled")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
@@ -1302,6 +1364,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.CustomMenuItems != after.CustomMenuItems {
changed = append(changed, "custom_menu_items")
}
if before.CustomEndpoints != after.CustomEndpoints {
changed = append(changed, "custom_endpoints")
}
if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
changed = append(changed, "enable_fingerprint_unification")
}
@@ -1311,6 +1376,22 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EnableCCHSigning != after.EnableCCHSigning {
changed = append(changed, "enable_cch_signing")
}
// Balance & quota notification
if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
changed = append(changed, "balance_low_notify_enabled")
}
if before.BalanceLowNotifyThreshold != after.BalanceLowNotifyThreshold {
changed = append(changed, "balance_low_notify_threshold")
}
if before.BalanceLowNotifyRechargeURL != after.BalanceLowNotifyRechargeURL {
changed = append(changed, "balance_low_notify_recharge_url")
}
if before.AccountQuotaNotifyEnabled != after.AccountQuotaNotifyEnabled {
changed = append(changed, "account_quota_notify_enabled")
}
if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
changed = append(changed, "account_quota_notify_emails")
}
return changed
}
@@ -1367,6 +1448,18 @@ func equalIntSlice(a, b []int) bool {
return true
}
func equalNotifyEmailEntries(a, b []service.NotifyEmailEntry) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].Email != b[i].Email || a[i].Verified != b[i].Verified || a[i].Disabled != b[i].Disabled {
return false
}
}
return true
}
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host"`
@@ -1847,3 +1940,80 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
})
}
// GetWebSearchEmulationConfig 获取 Web Search 模拟配置
// GET /api/v1/admin/settings/web-search-emulation
func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) {
cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), cfg))
}
// UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置
// PUT /api/v1/admin/settings/web-search-emulation
func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) {
var cfg service.WebSearchEmulationConfig
if err := c.ShouldBindJSON(&cfg); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil {
response.ErrorFrom(c, err)
return
}
// Re-read (with sanitized api keys) to return current state
updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), updated))
}
// ResetWebSearchUsage 重置指定 provider 的配额用量
// POST /api/v1/admin/settings/web-search-emulation/reset-usage
func (h *SettingHandler) ResetWebSearchUsage(c *gin.Context) {
var req struct {
ProviderType string `json:"provider_type"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.ProviderType == "" {
response.BadRequest(c, "provider_type is required")
return
}
if err := service.ResetWebSearchUsage(c.Request.Context(), req.ProviderType); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, nil)
}
// TestWebSearchEmulation 测试 Web Search 搜索
// POST /api/v1/admin/settings/web-search-emulation/test
func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) {
var req struct {
Query string `json:"query"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if strings.TrimSpace(req.Query) == "" {
req.Query = "搜索今年世界大事件"
}
result, err := service.TestWebSearch(c.Request.Context(), req.Query)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}

View File

@@ -13,16 +13,21 @@ func UserFromServiceShallow(u *service.User) *User {
return nil
}
return &User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: u.AllowedGroups,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
ID: u.ID,
Email: u.Email,
Username: u.Username,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: u.AllowedGroups,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
BalanceNotifyEnabled: u.BalanceNotifyEnabled,
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
TotalRecharged: u.TotalRecharged,
}
}
@@ -322,6 +327,26 @@ func AccountFromServiceShallow(a *service.Account) *Account {
out.QuotaWeeklyResetAt = &v
}
}
// 配额通知配置
if enabled := a.GetQuotaNotifyDailyEnabled(); enabled {
out.QuotaNotifyDailyEnabled = &enabled
}
if threshold := a.GetQuotaNotifyDailyThreshold(); threshold > 0 {
out.QuotaNotifyDailyThreshold = &threshold
}
if enabled := a.GetQuotaNotifyWeeklyEnabled(); enabled {
out.QuotaNotifyWeeklyEnabled = &enabled
}
if threshold := a.GetQuotaNotifyWeeklyThreshold(); threshold > 0 {
out.QuotaNotifyWeeklyThreshold = &threshold
}
if enabled := a.GetQuotaNotifyTotalEnabled(); enabled {
out.QuotaNotifyTotalEnabled = &enabled
}
if threshold := a.GetQuotaNotifyTotalThreshold(); threshold > 0 {
out.QuotaNotifyTotalThreshold = &threshold
}
}
return out
@@ -603,6 +628,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
ModelMappingChain: l.ModelMappingChain,
BillingTier: l.BillingTier,
AccountRateMultiplier: l.AccountRateMultiplier,
AccountStatsCost: l.AccountStatsCost,
IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account),
}

View File

@@ -0,0 +1,43 @@
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
// NotifyEmailEntry represents a notification email with enable/disable and verification state.
// All emails are user-managed; maximum 3 entries per user.
type NotifyEmailEntry struct {
Email string `json:"email"`
Disabled bool `json:"disabled"`
Verified bool `json:"verified"`
}
// NotifyEmailEntriesFromService converts service entries to DTO entries.
func NotifyEmailEntriesFromService(entries []service.NotifyEmailEntry) []NotifyEmailEntry {
if entries == nil {
return nil
}
result := make([]NotifyEmailEntry, len(entries))
for i, e := range entries {
result[i] = NotifyEmailEntry{
Email: e.Email,
Disabled: e.Disabled,
Verified: e.Verified,
}
}
return result
}
// NotifyEmailEntriesToService converts DTO entries to service entries.
func NotifyEmailEntriesToService(entries []NotifyEmailEntry) []service.NotifyEmailEntry {
if entries == nil {
return nil
}
result := make([]service.NotifyEmailEntry, len(entries))
for i, e := range entries {
result[i] = service.NotifyEmailEntry{
Email: e.Email,
Disabled: e.Disabled,
Verified: e.Verified,
}
}
return result
}

View File

@@ -124,20 +124,25 @@ type SystemSettings struct {
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
EnableCCHSigning bool `json:"enable_cch_signing"`
// Web Search Emulation
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
// Payment configuration
PaymentEnabled bool `json:"payment_enabled"`
PaymentMinAmount float64 `json:"payment_min_amount"`
PaymentMaxAmount float64 `json:"payment_max_amount"`
PaymentDailyLimit float64 `json:"payment_daily_limit"`
PaymentOrderTimeoutMin int `json:"payment_order_timeout_minutes"`
PaymentMaxPendingOrders int `json:"payment_max_pending_orders"`
PaymentEnabledTypes []string `json:"payment_enabled_types"`
PaymentBalanceDisabled bool `json:"payment_balance_disabled"`
PaymentLoadBalanceStrat string `json:"payment_load_balance_strategy"`
PaymentProductNamePrefix string `json:"payment_product_name_prefix"`
PaymentProductNameSuffix string `json:"payment_product_name_suffix"`
PaymentHelpImageURL string `json:"payment_help_image_url"`
PaymentHelpText string `json:"payment_help_text"`
PaymentEnabled bool `json:"payment_enabled"`
PaymentMinAmount float64 `json:"payment_min_amount"`
PaymentMaxAmount float64 `json:"payment_max_amount"`
PaymentDailyLimit float64 `json:"payment_daily_limit"`
PaymentOrderTimeoutMin int `json:"payment_order_timeout_minutes"`
PaymentMaxPendingOrders int `json:"payment_max_pending_orders"`
PaymentEnabledTypes []string `json:"payment_enabled_types"`
PaymentBalanceDisabled bool `json:"payment_balance_disabled"`
PaymentBalanceRechargeMultiplier float64 `json:"payment_balance_recharge_multiplier"`
PaymentRechargeFeeRate float64 `json:"payment_recharge_fee_rate"`
PaymentLoadBalanceStrat string `json:"payment_load_balance_strategy"`
PaymentProductNamePrefix string `json:"payment_product_name_prefix"`
PaymentProductNameSuffix string `json:"payment_product_name_suffix"`
PaymentHelpImageURL string `json:"payment_help_image_url"`
PaymentHelpText string `json:"payment_help_text"`
// Cancel rate limit
PaymentCancelRateLimitEnabled bool `json:"payment_cancel_rate_limit_enabled"`
@@ -145,6 +150,13 @@ type SystemSettings struct {
PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"`
PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
// Balance low notification
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
}
type DefaultSubscriptionSetting struct {
@@ -183,6 +195,10 @@ type PublicSettings struct {
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version"`
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO

View File

@@ -18,6 +18,13 @@ type User struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
// 余额不足通知
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
TotalRecharged float64 `json:"total_recharged"`
APIKeys []APIKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
@@ -218,6 +225,14 @@ type Account struct {
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
// 配额通知配置
QuotaNotifyDailyEnabled *bool `json:"quota_notify_daily_enabled,omitempty"`
QuotaNotifyDailyThreshold *float64 `json:"quota_notify_daily_threshold,omitempty"`
QuotaNotifyWeeklyEnabled *bool `json:"quota_notify_weekly_enabled,omitempty"`
QuotaNotifyWeeklyThreshold *float64 `json:"quota_notify_weekly_threshold,omitempty"`
QuotaNotifyTotalEnabled *bool `json:"quota_notify_total_enabled,omitempty"`
QuotaNotifyTotalThreshold *float64 `json:"quota_notify_total_threshold,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -412,6 +427,8 @@ type AdminUsageLog struct {
// AccountRateMultiplier 账号计费倍率快照nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
// AccountStatsCost 自定义定价规则计算的账号统计费用nil 表示使用默认公式)
AccountStatsCost *float64 `json:"account_stats_cost,omitempty"`
// IPAddress 用户请求 IP仅管理员可见
IPAddress *string `json:"ip_address,omitempty"`

View File

@@ -248,6 +248,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 设置请求所属分组 ID用于渠道级功能判断如 WebSearch 模拟)
parsedReq.GroupID = apiKey.GroupID
// 计算粘性会话hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
@@ -470,6 +473,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
APIKey: apiKey,
User: apiKey.User,
Account: account,
@@ -518,7 +522,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0))
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -672,6 +676,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 转发请求 - 根据账号平台分流
c.Set("parsed_request", parsedReq)
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
@@ -810,6 +815,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,

View File

@@ -168,6 +168,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // tlsFPProfileService
nil, // channelService
nil, // resolver
nil, // balanceNotifyService
)
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。

View File

@@ -126,26 +126,30 @@ func (h *PaymentHandler) GetCheckoutInfo(c *gin.Context) {
}
response.Success(c, checkoutInfoResponse{
Methods: limitsResp.Methods,
GlobalMin: limitsResp.GlobalMin,
GlobalMax: limitsResp.GlobalMax,
Plans: planList,
BalanceDisabled: cfg.BalanceDisabled,
HelpText: cfg.HelpText,
HelpImageURL: cfg.HelpImageURL,
StripePublishableKey: cfg.StripePublishableKey,
Methods: limitsResp.Methods,
GlobalMin: limitsResp.GlobalMin,
GlobalMax: limitsResp.GlobalMax,
Plans: planList,
BalanceDisabled: cfg.BalanceDisabled,
BalanceRechargeMultiplier: cfg.BalanceRechargeMultiplier,
RechargeFeeRate: cfg.RechargeFeeRate,
HelpText: cfg.HelpText,
HelpImageURL: cfg.HelpImageURL,
StripePublishableKey: cfg.StripePublishableKey,
})
}
type checkoutInfoResponse struct {
Methods map[string]service.MethodLimits `json:"methods"`
GlobalMin float64 `json:"global_min"`
GlobalMax float64 `json:"global_max"`
Plans []checkoutPlan `json:"plans"`
BalanceDisabled bool `json:"balance_disabled"`
HelpText string `json:"help_text"`
HelpImageURL string `json:"help_image_url"`
StripePublishableKey string `json:"stripe_publishable_key"`
Methods map[string]service.MethodLimits `json:"methods"`
GlobalMin float64 `json:"global_min"`
GlobalMax float64 `json:"global_max"`
Plans []checkoutPlan `json:"plans"`
BalanceDisabled bool `json:"balance_disabled"`
BalanceRechargeMultiplier float64 `json:"balance_recharge_multiplier"`
RechargeFeeRate float64 `json:"recharge_fee_rate"`
HelpText string `json:"help_text"`
HelpImageURL string `json:"help_image_url"`
StripePublishableKey string `json:"stripe_publishable_key"`
}
type checkoutPlan struct {
@@ -335,6 +339,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) {
response.Success(c, gin.H{"message": "refund requested"})
}
// GetRefundEligibleProviders returns provider instance IDs that allow user refund.
func (h *PaymentHandler) GetRefundEligibleProviders(c *gin.Context) {
ids, err := h.configService.GetUserRefundEligibleInstanceIDs(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"provider_instance_ids": ids})
}
// VerifyOrderRequest is the request body for verifying a payment order.
type VerifyOrderRequest struct {
OutTradeNo string `json:"out_trade_no" binding:"required"`
@@ -371,6 +385,7 @@ type PublicOrderResult struct {
Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"`
PaymentType string `json:"payment_type"`
OrderType string `json:"order_type"`
Status string `json:"status"`
}
@@ -394,6 +409,7 @@ func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
Amount: order.Amount,
PayAmount: order.PayAmount,
PaymentType: order.PaymentType,
OrderType: order.OrderType,
Status: order.Status,
})
}

View File

@@ -61,5 +61,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
Version: h.version,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
})
}

View File

@@ -11,13 +11,17 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
userService *service.UserService
userService *service.UserService
emailService *service.EmailService
emailCache service.EmailCache
}
// NewUserHandler creates a new UserHandler
func NewUserHandler(userService *service.UserService) *UserHandler {
func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler {
return &UserHandler{
userService: userService,
userService: userService,
emailService: emailService,
emailCache: emailCache,
}
}
@@ -29,7 +33,9 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
Username *string `json:"username"`
Username *string `json:"username"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
// GetProfile handles getting user profile
@@ -94,7 +100,9 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
}
svcReq := service.UpdateProfileRequest{
Username: req.Username,
Username: req.Username,
BalanceNotifyEnabled: req.BalanceNotifyEnabled,
BalanceNotifyThreshold: req.BalanceNotifyThreshold,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
@@ -104,3 +112,141 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
response.Success(c, dto.UserFromService(updatedUser))
}
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
type SendNotifyEmailCodeRequest struct {
Email string `json:"email" binding:"required,email"`
}
// SendNotifyEmailCode sends verification code to extra notification email
// POST /api/v1/user/notify-email/send-code
func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req SendNotifyEmailCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Verification code sent successfully"})
}
// VerifyNotifyEmailRequest represents the request to verify and add notify email
type VerifyNotifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
Code string `json:"code" binding:"required,len=6"`
}
// VerifyNotifyEmail verifies code and adds email to notification list
// POST /api/v1/user/notify-email/verify
func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req VerifyNotifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.VerifyAndAddNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Code, h.emailCache)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Return updated user
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(updatedUser))
}
// RemoveNotifyEmailRequest represents the request to remove a notify email
type RemoveNotifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
}
// RemoveNotifyEmail removes email from notification list
// DELETE /api/v1/user/notify-email
func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req RemoveNotifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.RemoveNotifyEmail(c.Request.Context(), subject.UserID, req.Email)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Return updated user
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(updatedUser))
}
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
type ToggleNotifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
Disabled bool `json:"disabled"`
}
// ToggleNotifyEmail toggles the disabled state of a notification email
// PUT /api/v1/user/notify-email/toggle
func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req ToggleNotifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.ToggleNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Disabled)
if err != nil {
response.ErrorFrom(c, err)
return
}
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(updatedUser))
}

View File

@@ -117,7 +117,13 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
var matched []*dbent.PaymentProviderInstance
for _, inst := range instances {
if InstanceSupportsType(inst.SupportedTypes, paymentType) {
// Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay),
// not "stripe" itself. The checkout page aggregates all sub-types under "stripe".
if paymentType == TypeStripe {
if inst.ProviderKey == TypeStripe {
matched = append(matched, inst)
}
} else if InstanceSupportsType(inst.SupportedTypes, paymentType) {
matched = append(matched, inst)
}
}

View File

@@ -242,7 +242,7 @@ func TestFilterByLimits(t *testing.T) {
wantIDs: nil,
},
{
name: "empty candidates returns empty",
name: "empty candidates returns empty",
candidates: nil,
paymentType: "alipay",
orderAmount: 10,

View File

@@ -98,9 +98,9 @@ func TestNewAlipay(t *testing.T) {
errSubstr: "privateKey",
},
{
name: "nil config map returns error for appId",
config: map[string]string{},
wantErr: true,
name: "nil config map returns error for appId",
config: map[string]string{},
wantErr: true,
errSubstr: "appId",
},
}

View File

@@ -18,6 +18,9 @@ const (
BlockTypeFunction
)
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
type UsageMapHook func(usageMap map[string]any)
// StreamingProcessor 流式响应处理器
type StreamingProcessor struct {
blockType BlockType
@@ -30,6 +33,7 @@ type StreamingProcessor struct {
originalModel string
webSearchQueries []string
groundingChunks []GeminiGroundingChunk
usageMapHook UsageMapHook
// 累计 usage
inputTokens int
@@ -46,6 +50,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
}
}
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
p.usageMapHook = fn
}
func usageToMap(u ClaudeUsage) map[string]any {
m := map[string]any{
"input_tokens": u.InputTokens,
"output_tokens": u.OutputTokens,
}
if u.CacheCreationInputTokens > 0 {
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
}
if u.CacheReadInputTokens > 0 {
m["cache_read_input_tokens"] = u.CacheReadInputTokens
}
if u.ImageOutputTokens > 0 {
m["image_output_tokens"] = u.ImageOutputTokens
}
return m
}
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line)
@@ -172,6 +198,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
responseID = "msg_" + generateRandomID()
}
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
message := map[string]any{
"id": responseID,
"type": "message",
@@ -180,7 +213,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
"model": p.originalModel,
"stop_reason": nil,
"stop_sequence": nil,
"usage": usage,
"usage": usageValue,
}
event := map[string]any{
@@ -492,13 +525,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
ImageOutputTokens: p.imageOutputTokens,
}
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
deltaEvent := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": stopReason,
"stop_sequence": nil,
},
"usage": usage,
"usage": usageValue,
}
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))

View File

@@ -27,13 +27,14 @@ func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest,
}
out := &ResponsesRequest{
Model: req.Model,
Input: inputJSON,
Temperature: req.Temperature,
TopP: req.TopP,
Stream: true, // upstream always streams
Include: []string{"reasoning.encrypted_content"},
ServiceTier: req.ServiceTier,
Model: req.Model,
Instructions: req.Instructions,
Input: inputJSON,
Temperature: req.Temperature,
TopP: req.TopP,
Stream: true, // upstream always streams
Include: []string{"reasoning.encrypted_content"},
ServiceTier: req.ServiceTier,
}
storeFalse := false

View File

@@ -152,6 +152,7 @@ type AnthropicDelta struct {
// ResponsesRequest is the request body for POST /v1/responses.
type ResponsesRequest struct {
Model string `json:"model"`
Instructions string `json:"instructions,omitempty"`
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
@@ -337,6 +338,7 @@ type ResponsesStreamEvent struct {
type ChatCompletionsRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
Instructions string `json:"instructions,omitempty"` // OpenAI Responses API compat
MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`

View File

@@ -10,7 +10,13 @@ import (
)
func TestInit_DualOutput(t *testing.T) {
tmpDir := t.TempDir()
// Use os.MkdirTemp instead of t.TempDir to avoid cleanup failures
// when lumberjack holds file handles on Windows.
tmpDir, err := os.MkdirTemp("", "logger-test-*")
if err != nil {
t.Fatalf("create temp dir: %v", err)
}
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
logPath := filepath.Join(tmpDir, "logs", "sub2api.log")
origStdout := os.Stdout
@@ -57,7 +63,9 @@ func TestInit_DualOutput(t *testing.T) {
L().Info("dual-output-info")
L().Warn("dual-output-warn")
Sync()
// Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
// The log data is already in the pipe buffer; closing writers is sufficient.
_ = stdoutW.Close()
_ = stderrW.Close()
@@ -166,7 +174,9 @@ func TestInit_CallerShouldPointToCallsite(t *testing.T) {
}
L().Info("caller-check")
Sync()
// Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutW.Close()
logBytes, _ := io.ReadAll(stdoutR)

View File

@@ -77,7 +77,7 @@ func TestStdLogBridgeRoutesLevels(t *testing.T) {
log.Printf("service started")
log.Printf("Warning: queue full")
log.Printf("Forward request failed: timeout")
Sync()
// Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
_ = stdoutW.Close()
_ = stderrW.Close()
@@ -139,7 +139,7 @@ func TestLegacyPrintfRoutesLevels(t *testing.T) {
LegacyPrintf("service.test", "request started")
LegacyPrintf("service.test", "Warning: queue full")
LegacyPrintf("service.test", "forward failed: timeout")
Sync()
// Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
_ = stdoutW.Close()
_ = stderrW.Close()

View File

@@ -56,8 +56,9 @@ type DashboardStats struct {
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
TotalAccountCost float64 `json:"total_account_cost"` // 累计账号成本
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
@@ -66,8 +67,9 @@ type DashboardStats struct {
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
TodayAccountCost float64 `json:"today_account_cost"` // 今日账号成本
// 系统运行统计
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
@@ -99,8 +101,9 @@ type ModelStat struct {
CacheCreationTokens int64 `json:"cache_creation_tokens"`
CacheReadTokens int64 `json:"cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
AccountCost float64 `json:"account_cost"` // 账号成本
}
// EndpointStat represents usage statistics for a single request endpoint.
@@ -125,8 +128,9 @@ type GroupStat struct {
GroupName string `json:"group_name"`
Requests int64 `json:"requests"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
AccountCost float64 `json:"account_cost"` // 账号成本
}
// UserUsageTrendPoint represents user usage trend data point
@@ -164,8 +168,9 @@ type UserBreakdownItem struct {
Email string `json:"email"`
Requests int64 `json:"requests"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
AccountCost float64 `json:"account_cost"` // 账号成本
}
// UserBreakdownDimension specifies the dimension to filter for user breakdown.

View File

@@ -0,0 +1,106 @@
package websearch
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
)
const (
braveSearchEndpoint = "https://api.search.brave.com/res/v1/web/search"
braveMaxCount = 20
braveProviderName = "brave"
)
// braveSearchURL is pre-parsed at init time; url.Parse cannot fail on a constant literal.
var braveSearchURL, _ = url.Parse(braveSearchEndpoint) //nolint:errcheck
// BraveProvider implements web search via the Brave Search API.
type BraveProvider struct {
apiKey string
httpClient *http.Client
}
// NewBraveProvider creates a Brave Search provider.
// The caller is responsible for configuring the http.Client with proxy/timeouts.
func NewBraveProvider(apiKey string, httpClient *http.Client) *BraveProvider {
if httpClient == nil {
httpClient = http.DefaultClient
}
return &BraveProvider{apiKey: apiKey, httpClient: httpClient}
}
func (b *BraveProvider) Name() string { return braveProviderName }
func (b *BraveProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
count := req.MaxResults
if count <= 0 {
count = defaultMaxResults
}
if count > braveMaxCount {
count = braveMaxCount
}
u := *braveSearchURL // copy the pre-parsed URL
q := u.Query()
q.Set("q", req.Query)
q.Set("count", strconv.Itoa(count))
u.RawQuery = q.Encode()
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return nil, fmt.Errorf("brave: build request: %w", err)
}
httpReq.Header.Set("X-Subscription-Token", b.apiKey)
httpReq.Header.Set("Accept", "application/json")
resp, err := b.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("brave: request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
if err != nil {
return nil, fmt.Errorf("brave: read body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("brave: status %d: %s", resp.StatusCode, truncateBody(body))
}
var raw braveResponse
if err := json.Unmarshal(body, &raw); err != nil {
return nil, fmt.Errorf("brave: decode response: %w", err)
}
results := make([]SearchResult, 0, len(raw.Web.Results))
for _, r := range raw.Web.Results {
results = append(results, SearchResult{
URL: r.URL,
Title: r.Title,
Snippet: r.Description,
PageAge: r.Age,
})
}
return &SearchResponse{Results: results, Query: req.Query}, nil
}
// braveResponse is the minimal structure of the Brave Search API response.
type braveResponse struct {
Web struct {
Results []braveResult `json:"results"`
} `json:"web"`
}
type braveResult struct {
URL string `json:"url"`
Title string `json:"title"`
Description string `json:"description"`
Age string `json:"age"`
}

View File

@@ -0,0 +1,119 @@
package websearch
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestBraveProvider_Name(t *testing.T) {
p := NewBraveProvider("key", nil)
require.Equal(t, "brave", p.Name())
}
func TestBraveProvider_Search_Success(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "test-key", r.Header.Get("X-Subscription-Token"))
require.Equal(t, "application/json", r.Header.Get("Accept"))
require.Equal(t, "golang", r.URL.Query().Get("q"))
require.Equal(t, "3", r.URL.Query().Get("count"))
resp := braveResponse{}
resp.Web.Results = []braveResult{
{URL: "https://go.dev", Title: "Go", Description: "Go lang", Age: "1 day"},
{URL: "https://pkg.go.dev", Title: "Pkg", Description: "Packages"},
{URL: "https://tour.go.dev", Title: "Tour", Description: "A Tour of Go", Age: "3 days"},
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
p := NewBraveProvider("test-key", srv.Client())
// Override the endpoint for testing
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srv.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
resp, err := p.Search(context.Background(), SearchRequest{Query: "golang", MaxResults: 3})
require.NoError(t, err)
require.Len(t, resp.Results, 3)
require.Equal(t, "https://go.dev", resp.Results[0].URL)
require.Equal(t, "Go lang", resp.Results[0].Snippet)
require.Equal(t, "1 day", resp.Results[0].PageAge)
}
func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) {
var receivedCount string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedCount = r.URL.Query().Get("count")
resp := braveResponse{}
_ = json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
p := NewBraveProvider("key", srv.Client())
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srv.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
_, _ = p.Search(context.Background(), SearchRequest{Query: "test", MaxResults: 0})
require.Equal(t, "5", receivedCount)
}
func TestBraveProvider_Search_HTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(429)
_, _ = w.Write([]byte("rate limited"))
}))
defer srv.Close()
p := NewBraveProvider("key", srv.Client())
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srv.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
_, err := p.Search(context.Background(), SearchRequest{Query: "test"})
require.ErrorContains(t, err, "brave: status 429")
}
func TestBraveProvider_Search_InvalidJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("not json"))
}))
defer srv.Close()
p := NewBraveProvider("key", srv.Client())
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srv.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
_, err := p.Search(context.Background(), SearchRequest{Query: "test"})
require.ErrorContains(t, err, "brave: decode response")
}
func TestBraveProvider_Search_EmptyResults(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
resp := braveResponse{}
_ = json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
p := NewBraveProvider("key", srv.Client())
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srv.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
resp, err := p.Search(context.Background(), SearchRequest{Query: "test"})
require.NoError(t, err)
require.Empty(t, resp.Results)
}

View File

@@ -0,0 +1,14 @@
package websearch
const (
maxResponseSize = 1 << 20 // 1 MB
errorBodyTruncLen = 200
)
// truncateBody returns a truncated string of body for error messages.
func truncateBody(body []byte) string {
if len(body) <= errorBodyTruncLen {
return string(body)
}
return string(body[:errorBodyTruncLen]) + "...(truncated)"
}

View File

@@ -0,0 +1,25 @@
package websearch
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestTruncateBody_Short(t *testing.T) {
body := []byte("short body")
require.Equal(t, "short body", truncateBody(body))
}
func TestTruncateBody_Long(t *testing.T) {
body := []byte(strings.Repeat("x", 500))
result := truncateBody(body)
require.Len(t, result, errorBodyTruncLen+len("...(truncated)"))
require.True(t, strings.HasSuffix(result, "...(truncated)"))
}
func TestTruncateBody_ExactBoundary(t *testing.T) {
body := []byte(strings.Repeat("x", errorBodyTruncLen))
require.Equal(t, string(body), truncateBody(body))
}

View File

@@ -0,0 +1,528 @@
package websearch
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"math/rand"
"net"
"net/http"
"net/url"
"sort"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/redis/go-redis/v9"
)
// ProviderConfig holds the configuration for a single search provider.
type ProviderConfig struct {
Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily
APIKey string `json:"api_key"` // secret
QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
SubscribedAt *int64 `json:"subscribed_at,omitempty"` // subscription start (unix seconds); quota resets monthly from this date
ProxyURL string `json:"-"` // resolved proxy URL (not persisted)
ProxyID int64 `json:"-"` // resolved proxy ID for unavailability tracking
ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds)
}
// Manager selects providers by quota-weighted load balancing and tracks quota via Redis.
type Manager struct {
configs []ProviderConfig
redis *redis.Client
clientMu sync.Mutex
clientCache map[string]*http.Client
}
// Timeout constants for proxy and search operations.
const (
proxyDialTimeout = 3 * time.Second // proxy TCP connection timeout
proxyTLSTimeout = 3 * time.Second // TLS handshake timeout
searchDataTimeout = 60 * time.Second // response data transfer timeout
searchRequestTimeout = searchDataTimeout + proxyDialTimeout
quotaKeyPrefix = "websearch:quota:"
proxyUnavailableKey = "websearch:proxy_unavailable:%d"
proxyUnavailableTTL = 5 * time.Minute
quotaTTLBuffer = 24 * time.Hour
defaultQuotaTTL = 31*24*time.Hour + quotaTTLBuffer // fallback when no subscription date
maxCachedClients = 100
)
// ErrProxyUnavailable indicates the search failed due to a proxy connectivity issue.
// Callers may use this to trigger account switching instead of direct fallback.
var ErrProxyUnavailable = errors.New("websearch: proxy unavailable")
// quotaIncrScript atomically increments the counter and sets TTL on first creation.
var quotaIncrScript = redis.NewScript(`
local val = redis.call('INCR', KEYS[1])
if val == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[1])
else
local ttl = redis.call('TTL', KEYS[1])
if ttl == -1 then
redis.call('EXPIRE', KEYS[1], ARGV[1])
end
end
return val
`)
// NewManager creates a Manager with the given provider configs and Redis client.
// Provider order is preserved as-is; selectByQuotaWeight handles load balancing.
func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager {
copied := make([]ProviderConfig, len(configs))
copy(copied, configs)
return &Manager{
configs: copied,
redis: redisClient,
clientCache: make(map[string]*http.Client),
}
}
// SearchWithBestProvider selects a provider using quota-weighted load balancing,
// reserves quota, executes the search, and rolls back quota on failure.
// If the search fails due to a proxy error, the proxy is marked unavailable for 5 minutes.
func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
if strings.TrimSpace(req.Query) == "" {
return nil, "", fmt.Errorf("websearch: empty search query")
}
candidates := m.filterAvailableProviders(ctx, req.ProxyURL)
if len(candidates) == 0 {
return nil, "", fmt.Errorf("websearch: no available provider (all exhausted, expired, or proxy unavailable)")
}
selected := m.selectByQuotaWeight(ctx, candidates)
for _, cfg := range selected {
allowed, incremented := m.tryReserveQuota(ctx, cfg)
if !allowed {
continue
}
resp, err := m.executeSearch(ctx, cfg, req)
if err != nil {
if incremented {
m.rollbackQuota(ctx, cfg)
}
if isProxyError(err) {
m.markProxyUnavailable(ctx, cfg, req.ProxyURL)
if req.ProxyURL != "" {
// Account-level proxy is shared by all providers — no point
// trying others with the same broken proxy; signal account switch.
slog.Warn("websearch: account proxy error, aborting failover",
"provider", cfg.Type, "error", err)
return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error())
}
// Provider-specific proxy failed — try the next provider which
// may use a different (or no) proxy.
slog.Warn("websearch: provider proxy error, trying next provider",
"provider", cfg.Type, "error", err)
continue
}
slog.Warn("websearch: provider search failed",
"provider", cfg.Type, "error", err)
continue
}
return resp, cfg.Type, nil
}
return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)")
}
// filterAvailableProviders returns providers that have API keys, are not expired,
// and whose proxies are not marked unavailable.
func (m *Manager) filterAvailableProviders(ctx context.Context, accountProxyURL string) []ProviderConfig {
var out []ProviderConfig
for _, cfg := range m.configs {
if !m.isProviderAvailable(cfg) {
continue
}
proxyID := resolveProxyID(cfg, accountProxyURL)
if proxyID > 0 && !m.isProxyAvailable(ctx, proxyID) {
slog.Debug("websearch: proxy marked unavailable, skipping",
"provider", cfg.Type, "proxy_id", proxyID)
continue
}
out = append(out, cfg)
}
return out
}
// weighted is a provider candidate with computed quota weight.
type weighted struct {
cfg ProviderConfig
weight int64
}
// selectByQuotaWeight orders candidates by remaining quota weight.
// Providers with quota_limit=0 (no limit set) get weight 0 and are placed last.
// Among providers with quota, higher remaining quota = higher priority.
func (m *Manager) selectByQuotaWeight(ctx context.Context, candidates []ProviderConfig) []ProviderConfig {
items := m.computeWeights(ctx, candidates)
withQuota, withoutQuota := partitionByQuota(items)
sortByStableRandomWeight(withQuota)
return mergeWeightedResults(withQuota, withoutQuota, len(candidates))
}
func (m *Manager) computeWeights(ctx context.Context, candidates []ProviderConfig) []weighted {
items := make([]weighted, 0, len(candidates))
for _, cfg := range candidates {
w := int64(0)
if cfg.QuotaLimit > 0 {
used, _ := m.GetUsage(ctx, cfg.Type)
if remaining := cfg.QuotaLimit - used; remaining > 0 {
w = remaining
}
}
items = append(items, weighted{cfg: cfg, weight: w})
}
return items
}
func partitionByQuota(items []weighted) (withQuota, withoutQuota []weighted) {
for _, item := range items {
if item.weight > 0 {
withQuota = append(withQuota, item)
} else {
withoutQuota = append(withoutQuota, item)
}
}
return
}
// sortByStableRandomWeight assigns a fixed random factor to each item before sorting,
// ensuring deterministic sort behavior (transitivity) within a single call.
func sortByStableRandomWeight(items []weighted) {
if len(items) <= 1 {
return
}
type entry struct {
item weighted
factor float64
}
entries := make([]entry, len(items))
for i, item := range items {
entries[i] = entry{item: item, factor: float64(item.weight) * (0.5 + rand.Float64())}
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].factor > entries[j].factor
})
for i, e := range entries {
items[i] = e.item
}
}
func mergeWeightedResults(withQuota, withoutQuota []weighted, capacity int) []ProviderConfig {
result := make([]ProviderConfig, 0, capacity)
for _, item := range withQuota {
result = append(result, item.cfg)
}
for _, item := range withoutQuota {
result = append(result, item.cfg)
}
return result
}
func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
if cfg.APIKey == "" {
return false
}
if cfg.ExpiresAt != nil && time.Now().Unix() > *cfg.ExpiresAt {
slog.Info("websearch: provider expired, skipping",
"provider", cfg.Type, "expires_at", *cfg.ExpiresAt)
return false
}
return true
}
// --- Proxy availability tracking ---
// markProxyUnavailable marks the effective proxy as unavailable for proxyUnavailableTTL.
func (m *Manager) markProxyUnavailable(ctx context.Context, cfg ProviderConfig, accountProxyURL string) {
proxyID := resolveProxyID(cfg, accountProxyURL)
if proxyID <= 0 || m.redis == nil {
return
}
key := fmt.Sprintf(proxyUnavailableKey, proxyID)
if err := m.redis.Set(ctx, key, "1", proxyUnavailableTTL).Err(); err != nil {
slog.Warn("websearch: failed to mark proxy unavailable",
"proxy_id", proxyID, "error", err)
}
}
// isProxyAvailable checks whether a proxy is currently marked as unavailable.
func (m *Manager) isProxyAvailable(ctx context.Context, proxyID int64) bool {
if m.redis == nil || proxyID <= 0 {
return true
}
key := fmt.Sprintf(proxyUnavailableKey, proxyID)
val, err := m.redis.Get(ctx, key).Result()
if err != nil {
return true // Redis error → assume available
}
return val == ""
}
// resolveProxyID determines the effective proxy ID for a provider+account combination.
func resolveProxyID(cfg ProviderConfig, accountProxyURL string) int64 {
if accountProxyURL != "" {
return 0 // account proxy has no ID in provider config
}
return cfg.ProxyID
}
// isProxyError checks whether the error is likely caused by proxy or network connectivity
// (as opposed to an API-level error from the search provider).
func isProxyError(err error) bool {
if err == nil {
return false
}
// Network-level errors (timeout, connection refused, DNS failure)
var netErr net.Error
if errors.As(err, &netErr) {
return true
}
var opErr *net.OpError
if errors.As(err, &opErr) {
return true
}
// TLS handshake failures (often caused by proxy intercepting/blocking)
var tlsErr *tls.RecordHeaderError
if errors.As(err, &tlsErr) {
return true
}
// String-based detection for wrapped errors
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "proxy") ||
strings.Contains(msg, "socks") ||
strings.Contains(msg, "connection refused") ||
strings.Contains(msg, "no such host") ||
strings.Contains(msg, "i/o timeout") ||
strings.Contains(msg, "tls handshake") ||
strings.Contains(msg, "certificate")
}
// --- Quota management ---
func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) {
if cfg.QuotaLimit <= 0 {
return true, false
}
if m.redis == nil {
slog.Warn("websearch: Redis unavailable, quota check skipped", "provider", cfg.Type)
return true, false
}
key := quotaRedisKey(cfg.Type)
ttlSec := int(quotaTTLFromSubscription(cfg.SubscribedAt).Seconds())
newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64()
if err != nil {
slog.Warn("websearch: quota Lua INCR failed, allowing request",
"provider", cfg.Type, "error", err)
return true, false
}
if newVal > cfg.QuotaLimit {
if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil {
slog.Warn("websearch: quota over-limit DECR failed",
"provider", cfg.Type, "error", decrErr)
}
slog.Info("websearch: provider quota exhausted",
"provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit)
return false, false
}
return true, true
}
func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
if cfg.QuotaLimit <= 0 || m.redis == nil {
return
}
key := quotaRedisKey(cfg.Type)
if err := m.redis.Decr(ctx, key).Err(); err != nil {
slog.Warn("websearch: quota rollback DECR failed",
"provider", cfg.Type, "error", err)
}
}
// --- Search execution ---
// TestSearch executes a search using the first available provider without reserving quota.
// Intended for admin test functionality only.
func (m *Manager) TestSearch(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
if strings.TrimSpace(req.Query) == "" {
return nil, "", fmt.Errorf("websearch: empty search query")
}
for _, cfg := range m.configs {
if !m.isProviderAvailable(cfg) {
continue
}
resp, err := m.executeSearch(ctx, cfg, req)
if err != nil {
continue
}
return resp, cfg.Type, nil
}
return nil, "", fmt.Errorf("websearch: no available provider")
}
func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) {
proxyURL := cfg.ProxyURL
if req.ProxyURL != "" {
proxyURL = req.ProxyURL
}
client, err := m.getOrCreateHTTPClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("websearch: %w", err)
}
provider := m.buildProvider(cfg, client)
return provider.Search(ctx, req)
}
// --- HTTP client cache ---
func (m *Manager) getOrCreateHTTPClient(proxyURL string) (*http.Client, error) {
m.clientMu.Lock()
defer m.clientMu.Unlock()
if c, ok := m.clientCache[proxyURL]; ok {
return c, nil
}
if len(m.clientCache) >= maxCachedClients {
m.clientCache = make(map[string]*http.Client)
}
c, err := newHTTPClient(proxyURL)
if err != nil {
return nil, err
}
m.clientCache[proxyURL] = c
return c, nil
}
// newHTTPClient creates an HTTP client with proper timeout settings.
// Uses proxyutil.ConfigureTransportProxy for unified proxy protocol support
// (HTTP/HTTPS/SOCKS5/SOCKS5H).
// Returns error if proxyURL is invalid — never falls back to direct connection.
func newHTTPClient(proxyURL string) (*http.Client, error) {
transport := &http.Transport{
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
DialContext: (&net.Dialer{Timeout: proxyDialTimeout}).DialContext,
TLSHandshakeTimeout: proxyTLSTimeout,
ResponseHeaderTimeout: searchDataTimeout,
}
if proxyURL != "" {
parsed, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL %q: %w", proxyURL, err)
}
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
return nil, fmt.Errorf("configure proxy: %w", err)
}
}
return &http.Client{Transport: transport, Timeout: searchRequestTimeout}, nil
}
// GetUsage returns the current usage count for the given provider.
func (m *Manager) GetUsage(ctx context.Context, providerType string) (int64, error) {
if m.redis == nil {
return 0, nil
}
key := quotaRedisKey(providerType)
val, err := m.redis.Get(ctx, key).Int64()
if err == redis.Nil {
return 0, nil
}
return val, err
}
// GetAllUsage returns usage for every configured provider.
func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 {
result := make(map[string]int64, len(m.configs))
for _, cfg := range m.configs {
used, _ := m.GetUsage(ctx, cfg.Type)
result[cfg.Type] = used
}
return result
}
// ResetUsage deletes the Redis quota key for the given provider, resetting usage to 0.
func (m *Manager) ResetUsage(ctx context.Context, providerType string) error {
if m.redis == nil {
return nil
}
key := quotaRedisKey(providerType)
return m.redis.Del(ctx, key).Err()
}
// --- Provider factory ---
func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider {
switch cfg.Type {
case braveProviderName:
return NewBraveProvider(cfg.APIKey, client)
case tavilyProviderName:
return NewTavilyProvider(cfg.APIKey, client)
default:
slog.Warn("websearch: unknown provider type, falling back to brave",
"type", cfg.Type)
return NewBraveProvider(cfg.APIKey, client)
}
}
// --- Redis key helpers ---
func quotaRedisKey(providerType string) string {
return quotaKeyPrefix + providerType
}
// quotaTTLFromSubscription calculates the TTL for the quota counter based on
// the provider's subscription start date. Quota resets monthly from that date.
// When the Redis key expires naturally, the next INCR creates a fresh counter (lazy refresh).
func quotaTTLFromSubscription(subscribedAt *int64) time.Duration {
if subscribedAt == nil || *subscribedAt == 0 {
return defaultQuotaTTL
}
next := nextMonthlyReset(time.Unix(*subscribedAt, 0).UTC())
ttl := time.Until(next) + quotaTTLBuffer
if ttl <= quotaTTLBuffer {
// Already past the reset — next cycle
ttl = defaultQuotaTTL
}
return ttl
}
// nextMonthlyReset returns the next monthly reset time based on the subscription start date.
// E.g., subscribed on Jan 15 → resets on Feb 15, Mar 15, etc.
// Handles day-of-month overflow: Jan 31 → Feb 28 (not Mar 3).
func nextMonthlyReset(subscribedAt time.Time) time.Time {
now := time.Now().UTC()
if subscribedAt.IsZero() {
return now.AddDate(0, 1, 0)
}
months := (now.Year()-subscribedAt.Year())*12 + int(now.Month()-subscribedAt.Month())
if months < 0 {
months = 0
}
candidate := addMonthsClamped(subscribedAt, months)
if candidate.After(now) {
return candidate
}
return addMonthsClamped(subscribedAt, months+1)
}
// addMonthsClamped adds N months to a date, clamping the day to the last day of the target month.
// E.g., Jan 31 + 1 month = Feb 28 (not Mar 3).
func addMonthsClamped(t time.Time, months int) time.Time {
y, m, d := t.Date()
targetMonth := time.Month(int(m) + months)
targetYear := y + int(targetMonth-1)/12
targetMonth = (targetMonth-1)%12 + 1
// Last day of the target month
lastDay := time.Date(targetYear, targetMonth+1, 0, 0, 0, 0, 0, time.UTC).Day()
if d > lastDay {
d = lastDay
}
return time.Date(targetYear, targetMonth, d, 0, 0, 0, 0, time.UTC)
}

View File

@@ -0,0 +1,323 @@
package websearch
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewManager_PreservesOrder(t *testing.T) {
configs := []ProviderConfig{
{Type: "brave", APIKey: "k3"},
{Type: "tavily", APIKey: "k1"},
}
m := NewManager(configs, nil)
require.Equal(t, "brave", m.configs[0].Type)
require.Equal(t, "tavily", m.configs[1].Type)
}
func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) {
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: ""})
require.ErrorContains(t, err, "empty search query")
_, _, err = m.SearchWithBestProvider(context.Background(), SearchRequest{Query: " "})
require.ErrorContains(t, err, "empty search query")
}
func TestManager_SearchWithBestProvider_SkipEmptyAPIKey(t *testing.T) {
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: ""}}, nil)
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
require.ErrorContains(t, err, "no available provider")
}
func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) {
past := time.Now().Add(-1 * time.Hour).Unix()
m := NewManager([]ProviderConfig{
{Type: "brave", APIKey: "k", ExpiresAt: &past},
}, nil)
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
require.ErrorContains(t, err, "no available provider")
}
func TestManager_SearchWithBestProvider_UsesFirstAvailable(t *testing.T) {
srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
resp := braveResponse{}
resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}}
_ = json.NewEncoder(w).Encode(resp)
}))
defer srvBrave.Close()
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srvBrave.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
m := NewManager([]ProviderConfig{
{Type: "brave", APIKey: "k1"},
{Type: "tavily", APIKey: "k2"},
}, nil)
m.clientCache[srvBrave.URL] = srvBrave.Client()
m.clientCache[""] = srvBrave.Client()
resp, providerName, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
require.NoError(t, err)
require.Equal(t, "brave", providerName)
require.Len(t, resp.Results, 1)
require.Equal(t, "from brave", resp.Results[0].Snippet)
}
func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
resp := braveResponse{}
resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}}
_ = json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srv.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
m := NewManager([]ProviderConfig{
{Type: "brave", APIKey: "k", QuotaLimit: 100},
}, nil)
m.clientCache[""] = srv.Client()
resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
require.NoError(t, err)
require.Len(t, resp.Results, 1)
}
func TestManager_GetUsage_NilRedis(t *testing.T) {
m := NewManager(nil, nil)
used, err := m.GetUsage(context.Background(), "brave")
require.NoError(t, err)
require.Equal(t, int64(0), used)
}
func TestManager_GetAllUsage_NilRedis(t *testing.T) {
m := NewManager([]ProviderConfig{
{Type: "brave"},
}, nil)
usage := m.GetAllUsage(context.Background())
require.Equal(t, int64(0), usage["brave"])
}
// --- Quota TTL from subscription ---
func TestQuotaTTLFromSubscription_NilSubscription(t *testing.T) {
ttl := quotaTTLFromSubscription(nil)
require.Equal(t, defaultQuotaTTL, ttl)
}
func TestQuotaTTLFromSubscription_ZeroSubscription(t *testing.T) {
zero := int64(0)
ttl := quotaTTLFromSubscription(&zero)
require.Equal(t, defaultQuotaTTL, ttl)
}
func TestQuotaTTLFromSubscription_ValidSubscription(t *testing.T) {
// Subscribed 10 days ago — next reset in ~20 days
sub := time.Now().Add(-10 * 24 * time.Hour).Unix()
ttl := quotaTTLFromSubscription(&sub)
require.Greater(t, ttl, 15*24*time.Hour) // at least 15 days
require.Less(t, ttl, 25*24*time.Hour+quotaTTLBuffer)
}
func TestNextMonthlyReset_SubscribedRecentPast(t *testing.T) {
// Subscribed on the 10th of this month (always valid day)
now := time.Now().UTC()
sub := time.Date(now.Year(), now.Month(), 10, 0, 0, 0, 0, time.UTC)
next := nextMonthlyReset(sub)
require.True(t, next.After(now) || next.Equal(now), "next reset should be in the future or now")
require.True(t, next.Before(now.AddDate(0, 1, 1)))
}
func TestNextMonthlyReset_SubscribedLongAgo(t *testing.T) {
// Subscribed 6 months ago on the 1st
sub := time.Now().UTC().AddDate(0, -6, 0)
sub = time.Date(sub.Year(), sub.Month(), 1, 0, 0, 0, 0, time.UTC)
next := nextMonthlyReset(sub)
require.True(t, next.After(time.Now().UTC()))
// Should be within the next 31 days
require.True(t, next.Before(time.Now().UTC().AddDate(0, 1, 1)))
}
func TestNextMonthlyReset_FutureSubscription(t *testing.T) {
sub := time.Now().UTC().AddDate(0, 0, 5)
next := nextMonthlyReset(sub)
require.True(t, next.After(time.Now().UTC()))
}
func TestAddMonthsClamped_Jan31ToFeb(t *testing.T) {
sub := time.Date(2026, 1, 31, 0, 0, 0, 0, time.UTC)
next := addMonthsClamped(sub, 1)
require.Equal(t, time.Month(2), next.Month())
require.Equal(t, 28, next.Day()) // Feb 28 (2026 is not a leap year)
}
func TestAddMonthsClamped_Jan31ToFebLeapYear(t *testing.T) {
sub := time.Date(2028, 1, 31, 0, 0, 0, 0, time.UTC)
next := addMonthsClamped(sub, 1)
require.Equal(t, time.Month(2), next.Month())
require.Equal(t, 29, next.Day()) // Feb 29 (2028 is a leap year)
}
func TestAddMonthsClamped_Mar31ToApr(t *testing.T) {
sub := time.Date(2026, 3, 31, 0, 0, 0, 0, time.UTC)
next := addMonthsClamped(sub, 1)
require.Equal(t, time.Month(4), next.Month())
require.Equal(t, 30, next.Day()) // Apr has 30 days
}
func TestAddMonthsClamped_NormalDay(t *testing.T) {
sub := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC)
next := addMonthsClamped(sub, 1)
require.Equal(t, time.Month(2), next.Month())
require.Equal(t, 15, next.Day()) // no clamping needed
}
// --- Redis key ---
func TestQuotaRedisKey_Format(t *testing.T) {
key := quotaRedisKey("brave")
require.Equal(t, "websearch:quota:brave", key)
}
// --- isProviderAvailable ---
func TestIsProviderAvailable_EmptyAPIKey(t *testing.T) {
m := NewManager(nil, nil)
require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: ""}))
}
func TestIsProviderAvailable_Expired(t *testing.T) {
m := NewManager(nil, nil)
past := time.Now().Add(-1 * time.Hour).Unix()
require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &past}))
}
func TestIsProviderAvailable_Valid(t *testing.T) {
m := NewManager(nil, nil)
future := time.Now().Add(1 * time.Hour).Unix()
require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &future}))
require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k"})) // no expiry
}
// --- resolveProxyID ---
func TestResolveProxyID_AccountProxyOverrides(t *testing.T) {
cfg := ProviderConfig{ProxyID: 42}
require.Equal(t, int64(0), resolveProxyID(cfg, "http://account-proxy:8080"))
require.Equal(t, int64(42), resolveProxyID(cfg, ""))
}
// --- isProxyError ---
func TestIsProxyError_Nil(t *testing.T) {
require.False(t, isProxyError(nil))
}
func TestIsProxyError_ConnectionRefused(t *testing.T) {
require.True(t, isProxyError(fmt.Errorf("dial tcp: connection refused")))
}
func TestIsProxyError_Timeout(t *testing.T) {
require.True(t, isProxyError(fmt.Errorf("i/o timeout while connecting to proxy")))
}
func TestIsProxyError_SOCKS(t *testing.T) {
require.True(t, isProxyError(fmt.Errorf("socks connect failed")))
}
func TestIsProxyError_TLSHandshake(t *testing.T) {
require.True(t, isProxyError(fmt.Errorf("tls handshake timeout")))
}
func TestIsProxyError_APIError_NotProxy(t *testing.T) {
require.False(t, isProxyError(fmt.Errorf("API rate limit exceeded")))
}
// --- isProxyAvailable (nil Redis) ---
func TestIsProxyAvailable_NilRedis(t *testing.T) {
m := NewManager(nil, nil)
require.True(t, m.isProxyAvailable(context.Background(), 42))
}
func TestIsProxyAvailable_ZeroID(t *testing.T) {
m := NewManager(nil, nil)
require.True(t, m.isProxyAvailable(context.Background(), 0))
}
// --- selectByQuotaWeight ---
func TestSelectByQuotaWeight_NoQuotaLast(t *testing.T) {
m := NewManager(nil, nil)
candidates := []ProviderConfig{
{Type: "brave", APIKey: "k1", QuotaLimit: 0},
{Type: "tavily", APIKey: "k2", QuotaLimit: 100},
}
result := m.selectByQuotaWeight(context.Background(), candidates)
require.Len(t, result, 2)
require.Equal(t, "tavily", result[0].Type)
require.Equal(t, "brave", result[1].Type)
}
func TestSelectByQuotaWeight_AllNoQuota(t *testing.T) {
m := NewManager(nil, nil)
candidates := []ProviderConfig{
{Type: "brave", APIKey: "k1", QuotaLimit: 0},
{Type: "tavily", APIKey: "k2", QuotaLimit: 0},
}
result := m.selectByQuotaWeight(context.Background(), candidates)
require.Len(t, result, 2)
}
func TestSelectByQuotaWeight_Empty(t *testing.T) {
m := NewManager(nil, nil)
result := m.selectByQuotaWeight(context.Background(), nil)
require.Empty(t, result)
}
// --- newHTTPClient ---
func TestNewHTTPClient_NoProxy(t *testing.T) {
c, err := newHTTPClient("")
require.NoError(t, err)
require.NotNil(t, c)
}
func TestNewHTTPClient_InvalidProxy(t *testing.T) {
_, err := newHTTPClient("://bad-url")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid proxy URL")
}
func TestNewHTTPClient_ValidHTTPProxy(t *testing.T) {
c, err := newHTTPClient("http://proxy.example.com:8080")
require.NoError(t, err)
require.NotNil(t, c)
}
func TestNewHTTPClient_ValidSOCKS5Proxy(t *testing.T) {
c, err := newHTTPClient("socks5://proxy.example.com:1080")
require.NoError(t, err)
require.NotNil(t, c)
}
// --- ResetUsage ---
func TestManager_ResetUsage_NilRedis(t *testing.T) {
m := NewManager(nil, nil)
err := m.ResetUsage(context.Background(), "brave")
require.NoError(t, err)
}

View File

@@ -0,0 +1,11 @@
package websearch
import "context"
// Provider is the interface every search backend must implement.
type Provider interface {
// Name returns the provider identifier ("brave" or "tavily").
Name() string
// Search executes a web search and returns results.
Search(ctx context.Context, req SearchRequest) (*SearchResponse, error)
}

View File

@@ -0,0 +1,107 @@
package websearch
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
)
const (
tavilySearchEndpoint = "https://api.tavily.com/search"
tavilyProviderName = "tavily"
tavilySearchDepthBasic = "basic"
)
// TavilyProvider implements web search via the Tavily Search API.
type TavilyProvider struct {
apiKey string
httpClient *http.Client
}
// NewTavilyProvider creates a Tavily Search provider.
// The caller is responsible for configuring the http.Client with proxy/timeouts.
func NewTavilyProvider(apiKey string, httpClient *http.Client) *TavilyProvider {
if httpClient == nil {
httpClient = http.DefaultClient
}
return &TavilyProvider{apiKey: apiKey, httpClient: httpClient}
}
func (t *TavilyProvider) Name() string { return tavilyProviderName }
func (t *TavilyProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
maxResults := req.MaxResults
if maxResults <= 0 {
maxResults = defaultMaxResults
}
payload := tavilyRequest{
APIKey: t.apiKey,
Query: req.Query,
MaxResults: maxResults,
SearchDepth: tavilySearchDepthBasic,
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("tavily: encode request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilySearchEndpoint, bytes.NewReader(bodyBytes))
if err != nil {
return nil, fmt.Errorf("tavily: build request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := t.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("tavily: request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
if err != nil {
return nil, fmt.Errorf("tavily: read body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("tavily: status %d: %s", resp.StatusCode, truncateBody(body))
}
var raw tavilyResponse
if err := json.Unmarshal(body, &raw); err != nil {
return nil, fmt.Errorf("tavily: decode response: %w", err)
}
results := make([]SearchResult, 0, len(raw.Results))
for _, r := range raw.Results {
results = append(results, SearchResult{
URL: r.URL,
Title: r.Title,
Snippet: r.Content,
})
}
return &SearchResponse{Results: results, Query: req.Query}, nil
}
type tavilyRequest struct {
APIKey string `json:"api_key"`
Query string `json:"query"`
MaxResults int `json:"max_results"`
SearchDepth string `json:"search_depth"`
}
type tavilyResponse struct {
Results []tavilyResult `json:"results"`
}
type tavilyResult struct {
URL string `json:"url"`
Title string `json:"title"`
Content string `json:"content"`
Score float64 `json:"score"`
}

View File

@@ -0,0 +1,63 @@
package websearch
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestTavilyProvider_Name(t *testing.T) {
p := NewTavilyProvider("key", nil)
require.Equal(t, "tavily", p.Name())
}
func TestTavilyProvider_Search_RequestConstruction(t *testing.T) {
// Verify tavilyRequest struct fields map correctly
req := tavilyRequest{
APIKey: "test-key",
Query: "golang",
MaxResults: 3,
SearchDepth: tavilySearchDepthBasic,
}
data, err := json.Marshal(req)
require.NoError(t, err)
var parsed map[string]any
require.NoError(t, json.Unmarshal(data, &parsed))
require.Equal(t, "test-key", parsed["api_key"])
require.Equal(t, "golang", parsed["query"])
require.Equal(t, float64(3), parsed["max_results"])
require.Equal(t, "basic", parsed["search_depth"])
}
func TestTavilyProvider_Search_ResponseParsing(t *testing.T) {
rawResp := `{"results":[{"url":"https://go.dev","title":"Go","content":"Go programming language","score":0.95}]}`
var resp tavilyResponse
require.NoError(t, json.Unmarshal([]byte(rawResp), &resp))
require.Len(t, resp.Results, 1)
require.Equal(t, "https://go.dev", resp.Results[0].URL)
require.Equal(t, "Go programming language", resp.Results[0].Content)
require.InDelta(t, 0.95, resp.Results[0].Score, 0.001)
// Verify mapping to SearchResult
results := make([]SearchResult, 0, len(resp.Results))
for _, r := range resp.Results {
results = append(results, SearchResult{
URL: r.URL, Title: r.Title, Snippet: r.Content,
})
}
require.Equal(t, "Go programming language", results[0].Snippet)
require.Equal(t, "", results[0].PageAge)
}
func TestTavilyProvider_Search_EmptyResults(t *testing.T) {
var resp tavilyResponse
require.NoError(t, json.Unmarshal([]byte(`{"results":[]}`), &resp))
require.Empty(t, resp.Results)
}
func TestTavilyProvider_Search_InvalidJSON(t *testing.T) {
var resp tavilyResponse
require.Error(t, json.Unmarshal([]byte("not json"), &resp))
}

View File

@@ -0,0 +1,30 @@
package websearch
// SearchResult represents a single web search result.
type SearchResult struct {
URL string `json:"url"`
Title string `json:"title"`
Snippet string `json:"snippet"`
PageAge string `json:"page_age,omitempty"`
}
// SearchRequest describes a web search to perform.
type SearchRequest struct {
Query string
MaxResults int // defaults to defaultMaxResults if <= 0
ProxyURL string // optional HTTP proxy URL
}
// SearchResponse holds the results of a web search.
type SearchResponse struct {
Results []SearchResult
Query string // the query that was actually executed
}
const defaultMaxResults = 5
// Provider type identifiers.
const (
ProviderTypeBrave = "brave"
ProviderTypeTavily = "tavily"
)

View File

@@ -138,10 +138,17 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
WithUser(func(q *dbent.UserQuery) {
q.Select(
user.FieldID,
user.FieldEmail,
user.FieldUsername,
user.FieldStatus,
user.FieldRole,
user.FieldBalance,
user.FieldConcurrency,
user.FieldBalanceNotifyEnabled,
user.FieldBalanceNotifyThresholdType,
user.FieldBalanceNotifyThreshold,
user.FieldBalanceNotifyExtraEmails,
user.FieldTotalRecharged,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
@@ -639,22 +646,31 @@ func userEntityToService(u *dbent.User) *service.User {
if u == nil {
return nil
}
return &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
out := &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
BalanceNotifyEnabled: u.BalanceNotifyEnabled,
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
TotalRecharged: u.TotalRecharged,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
// Parse extra emails JSON (supports both old []string and new []NotifyEmailEntry format)
if u.BalanceNotifyExtraEmails != "" && u.BalanceNotifyExtraEmails != "[]" {
out.BalanceNotifyExtraEmails = service.ParseNotifyEmails(u.BalanceNotifyExtraEmails)
}
return out
}
func groupEntityToService(g *dbent.Group) *service.Group {

View File

@@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
if err != nil {
return err
}
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
if err != nil {
return err
}
err = tx.QueryRowContext(ctx,
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6)
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id, created_at, updated_at`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
@@ -67,17 +71,24 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
}
}
// 设置账号统计定价规则
if len(channel.AccountStatsPricingRules) > 0 {
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
return err
}
}
return nil
})
}
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
ch := &service.Channel{}
var modelMappingJSON []byte
var modelMappingJSON, featuresConfigJSON []byte
err := r.db.QueryRowContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at
FROM channels WHERE id = $1`, id,
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt)
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound
}
@@ -85,6 +96,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
return nil, fmt.Errorf("get channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
groupIDs, err := r.GetGroupIDs(ctx, id)
if err != nil {
@@ -98,6 +110,12 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
}
ch.ModelPricing = pricing
statsPricingRules, err := r.loadAccountStatsPricingRules(ctx, id)
if err != nil {
return nil, err
}
ch.AccountStatsPricingRules = statsPricingRules
return ch, nil
}
@@ -107,10 +125,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
if err != nil {
return err
}
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
if err != nil {
return err
}
result, err := tx.ExecContext(ctx,
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW()
WHERE id = $7`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID,
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, apply_pricing_to_account_stats = $9, updated_at = NOW()
WHERE id = $10`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, channel.ID,
)
if err != nil {
if isUniqueViolation(err) {
@@ -137,6 +159,13 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
}
}
// 更新账号统计定价规则
if channel.AccountStatsPricingRules != nil {
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
return err
}
}
return nil
})
}
@@ -187,7 +216,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表
dataQuery := fmt.Sprintf(
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.apply_pricing_to_account_stats, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
)
@@ -203,11 +232,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
var channelIDs []int64
for rows.Next() {
var ch service.Channel
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
var modelMappingJSON, featuresConfigJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -225,9 +255,14 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
if err != nil {
return nil, nil, err
}
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
if err != nil {
return nil, nil, err
}
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
}
}
@@ -273,7 +308,7 @@ func channelListOrderBy(params pagination.PaginationParams) string {
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`,
)
if err != nil {
return nil, fmt.Errorf("query all channels: %w", err)
@@ -284,11 +319,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
var channelIDs []int64
for rows.Next() {
var ch service.Channel
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
var modelMappingJSON, featuresConfigJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -312,9 +348,16 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
return nil, err
}
// 批量加载账号统计定价规则
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
if err != nil {
return nil, err
}
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
}
return channels, nil
@@ -456,6 +499,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
return m
}
func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
if len(m) == 0 {
return []byte("{}"), nil
}
data, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("marshal features_config: %w", err)
}
return data, nil
}
func unmarshalFeaturesConfig(data []byte) map[string]any {
if len(data) == 0 {
return nil
}
var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return nil
}
return m
}
// GetGroupPlatforms 批量查询分组 ID 对应的平台
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
if len(groupIDs) == 0 {

View File

@@ -0,0 +1,244 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
// --- 账号统计定价规则 ---
// batchLoadAccountStatsPricingRules 批量加载多个渠道的账号统计定价规则(含模型定价)
func (r *channelRepository) batchLoadAccountStatsPricingRules(ctx context.Context, channelIDs []int64) (map[int64][]service.AccountStatsPricingRule, error) {
// 1. 查询规则
rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, name, group_ids, account_ids, sort_order, created_at, updated_at
FROM channel_account_stats_pricing_rules WHERE channel_id = ANY($1) ORDER BY channel_id, sort_order, id`,
pq.Array(channelIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load account stats pricing rules: %w", err)
}
defer func() { _ = rows.Close() }()
var allRules []service.AccountStatsPricingRule
var ruleIDs []int64
for rows.Next() {
var rule service.AccountStatsPricingRule
if err := rows.Scan(
&rule.ID, &rule.ChannelID, &rule.Name,
pq.Array(&rule.GroupIDs), pq.Array(&rule.AccountIDs),
&rule.SortOrder, &rule.CreatedAt, &rule.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan account stats pricing rule: %w", err)
}
ruleIDs = append(ruleIDs, rule.ID)
allRules = append(allRules, rule)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate account stats pricing rules: %w", err)
}
// 2. 批量加载规则的模型定价
pricingMap, err := r.batchLoadAccountStatsModelPricing(ctx, ruleIDs)
if err != nil {
return nil, err
}
// 3. 按 channelID 分组并关联定价
result := make(map[int64][]service.AccountStatsPricingRule, len(channelIDs))
for i := range allRules {
allRules[i].Pricing = pricingMap[allRules[i].ID]
result[allRules[i].ChannelID] = append(result[allRules[i].ChannelID], allRules[i])
}
return result, nil
}
// batchLoadAccountStatsModelPricing 批量加载规则的模型定价
func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Context, ruleIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
if len(ruleIDs) == 0 {
return make(map[int64][]service.ChannelModelPricing), nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT id, rule_id, platform, models, billing_mode, input_price, output_price,
cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_account_stats_model_pricing WHERE rule_id = ANY($1) ORDER BY rule_id, id`,
pq.Array(ruleIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load account stats model pricing: %w", err)
}
defer func() { _ = rows.Close() }()
pricingMap := make(map[int64][]service.ChannelModelPricing, len(ruleIDs))
for rows.Next() {
var p service.ChannelModelPricing
var ruleID int64
var modelsJSON []byte
if err := rows.Scan(
&p.ID, &ruleID, &p.Platform, &modelsJSON, &p.BillingMode,
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan account stats model pricing: %w", err)
}
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
p.Models = []string{}
}
pricingMap[ruleID] = append(pricingMap[ruleID], p)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate account stats model pricing: %w", err)
}
// Load intervals for all pricing entries.
var allPricingIDs []int64
for _, pricings := range pricingMap {
for _, p := range pricings {
allPricingIDs = append(allPricingIDs, p.ID)
}
}
if len(allPricingIDs) > 0 {
intervalsMap, err := r.batchLoadAccountStatsIntervals(ctx, allPricingIDs)
if err != nil {
return nil, err
}
for ruleID, pricings := range pricingMap {
for i := range pricings {
pricings[i].Intervals = intervalsMap[pricings[i].ID]
}
pricingMap[ruleID] = pricings
}
}
return pricingMap, nil
}
// loadAccountStatsPricingRules 加载单个渠道的账号统计定价规则(供 GetByID 使用)
func (r *channelRepository) loadAccountStatsPricingRules(ctx context.Context, channelID int64) ([]service.AccountStatsPricingRule, error) {
result, err := r.batchLoadAccountStatsPricingRules(ctx, []int64{channelID})
if err != nil {
return nil, err
}
return result[channelID], nil
}
// replaceAccountStatsPricingRulesTx 在事务中替换渠道的账号统计定价规则(删除旧的 + 插入新的)
func replaceAccountStatsPricingRulesTx(ctx context.Context, tx *sql.Tx, channelID int64, rules []service.AccountStatsPricingRule) error {
// CASCADE 会自动删除关联的 model_pricing
if _, err := tx.ExecContext(ctx,
`DELETE FROM channel_account_stats_pricing_rules WHERE channel_id = $1`, channelID,
); err != nil {
return fmt.Errorf("delete old account stats pricing rules: %w", err)
}
for i := range rules {
rules[i].ChannelID = channelID
if err := createAccountStatsPricingRuleTx(ctx, tx, &rules[i]); err != nil {
return fmt.Errorf("insert account stats pricing rule: %w", err)
}
}
return nil
}
// createAccountStatsPricingRuleTx 在事务中创建单条账号统计定价规则及其模型定价
func createAccountStatsPricingRuleTx(ctx context.Context, tx *sql.Tx, rule *service.AccountStatsPricingRule) error {
err := tx.QueryRowContext(ctx,
`INSERT INTO channel_account_stats_pricing_rules (channel_id, name, group_ids, account_ids, sort_order)
VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
rule.ChannelID, rule.Name, pq.Array(rule.GroupIDs), pq.Array(rule.AccountIDs), rule.SortOrder,
).Scan(&rule.ID, &rule.CreatedAt, &rule.UpdatedAt)
if err != nil {
return fmt.Errorf("insert account stats pricing rule: %w", err)
}
for j := range rule.Pricing {
if err := createAccountStatsModelPricingTx(ctx, tx, rule.ID, &rule.Pricing[j]); err != nil {
return err
}
}
return nil
}
// createAccountStatsModelPricingTx 在事务中创建单条账号统计模型定价
func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID int64, pricing *service.ChannelModelPricing) error {
modelsJSON, err := json.Marshal(pricing.Models)
if err != nil {
return fmt.Errorf("marshal models: %w", err)
}
billingMode := pricing.BillingMode
if billingMode == "" {
billingMode = service.BillingModeToken
}
platform := pricing.Platform
err = tx.QueryRowContext(ctx,
`INSERT INTO channel_account_stats_model_pricing (rule_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
ruleID, platform, modelsJSON, billingMode,
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.PerRequestPrice,
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
if err != nil {
return fmt.Errorf("insert account stats model pricing: %w", err)
}
// Persist intervals (mirrors channel_pricing_intervals logic).
for i := range pricing.Intervals {
iv := &pricing.Intervals[i]
iv.PricingID = pricing.ID
if err := createAccountStatsIntervalTx(ctx, tx, iv); err != nil {
return err
}
}
return nil
}
// createAccountStatsIntervalTx inserts a single interval for an account stats pricing entry.
func createAccountStatsIntervalTx(ctx context.Context, tx *sql.Tx, iv *service.PricingInterval) error {
return tx.QueryRowContext(ctx,
`INSERT INTO channel_account_stats_pricing_intervals
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
iv.PerRequestPrice, iv.SortOrder,
).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
}
// batchLoadAccountStatsIntervals loads intervals for account stats pricing entries.
func (r *channelRepository) batchLoadAccountStatsIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
if len(pricingIDs) == 0 {
return nil, nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
input_price, output_price, cache_write_price, cache_read_price,
per_request_price, sort_order, created_at, updated_at
FROM channel_account_stats_pricing_intervals
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
pq.Array(pricingIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load account stats pricing intervals: %w", err)
}
defer func() { _ = rows.Close() }()
result := make(map[int64][]service.PricingInterval)
for rows.Next() {
var iv service.PricingInterval
if err := rows.Scan(
&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan account stats pricing interval: %w", err)
}
result[iv.PricingID] = append(result[iv.PricingID], iv)
}
return result, rows.Err()
}

View File

@@ -331,6 +331,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
COALESCE(SUM(total_cost), 0) AS total_cost,
COALESCE(SUM(actual_cost), 0) AS actual_cost,
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) AS account_cost,
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
@@ -351,6 +352,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
cache_read_tokens,
total_cost,
actual_cost,
account_cost,
total_duration_ms,
active_users,
computed_at
@@ -364,6 +366,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
hourly.cache_read_tokens,
hourly.total_cost,
hourly.actual_cost,
hourly.account_cost,
hourly.total_duration_ms,
COALESCE(user_counts.active_users, 0) AS active_users,
NOW()
@@ -378,6 +381,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
cache_read_tokens = EXCLUDED.cache_read_tokens,
total_cost = EXCLUDED.total_cost,
actual_cost = EXCLUDED.actual_cost,
account_cost = EXCLUDED.account_cost,
total_duration_ms = EXCLUDED.total_duration_ms,
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
@@ -399,6 +403,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
COALESCE(SUM(total_cost), 0) AS total_cost,
COALESCE(SUM(actual_cost), 0) AS actual_cost,
COALESCE(SUM(account_cost), 0) AS account_cost,
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
FROM usage_dashboard_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
@@ -419,6 +424,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
cache_read_tokens,
total_cost,
actual_cost,
account_cost,
total_duration_ms,
active_users,
computed_at
@@ -432,6 +438,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
daily.cache_read_tokens,
daily.total_cost,
daily.actual_cost,
daily.account_cost,
daily.total_duration_ms,
COALESCE(user_counts.active_users, 0) AS active_users,
NOW()
@@ -446,6 +453,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
cache_read_tokens = EXCLUDED.cache_read_tokens,
total_cost = EXCLUDED.total_cost,
actual_cost = EXCLUDED.actual_cost,
account_cost = EXCLUDED.account_cost,
total_duration_ms = EXCLUDED.total_duration_ms,
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at

View File

@@ -3,6 +3,8 @@ package repository
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -11,23 +13,33 @@ import (
const (
verifyCodeKeyPrefix = "verify_code:"
notifyVerifyKeyPrefix = "notify_verify:"
passwordResetKeyPrefix = "password_reset:"
passwordResetSentAtKeyPrefix = "password_reset_sent:"
notifyCodeUserRateKeyPrefix = "notify_code_user_rate:"
)
// verifyCodeKey generates the Redis key for email verification code.
// Email is lowercased for case-insensitive consistency.
func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
return verifyCodeKeyPrefix + strings.ToLower(email)
}
// notifyVerifyKey generates the Redis key for notify email verification code.
// Email is lowercased to prevent case-sensitive key mismatch (the business layer
// uses strings.EqualFold for comparison).
func notifyVerifyKey(email string) string {
return notifyVerifyKeyPrefix + strings.ToLower(email)
}
// passwordResetKey generates the Redis key for password reset token.
func passwordResetKey(email string) string {
return passwordResetKeyPrefix + email
return passwordResetKeyPrefix + strings.ToLower(email)
}
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
func passwordResetSentAtKey(email string) string {
return passwordResetSentAtKeyPrefix + email
return passwordResetSentAtKeyPrefix + strings.ToLower(email)
}
type emailCache struct {
@@ -106,3 +118,60 @@ func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email st
key := passwordResetSentAtKey(email)
return c.rdb.Set(ctx, key, "1", ttl).Err()
}
// Notify email verification code methods
func (c *emailCache) GetNotifyVerifyCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
key := notifyVerifyKey(email)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data service.VerificationCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
func (c *emailCache) SetNotifyVerifyCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
key := notifyVerifyKey(email)
val, err := json.Marshal(data)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *emailCache) DeleteNotifyVerifyCode(ctx context.Context, email string) error {
key := notifyVerifyKey(email)
return c.rdb.Del(ctx, key).Err()
}
// User-level rate limiting for notify email verification codes
func notifyCodeUserRateKey(userID int64) string {
return notifyCodeUserRateKeyPrefix + fmt.Sprintf("%d", userID)
}
func (c *emailCache) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
key := notifyCodeUserRateKey(userID)
count, err := c.rdb.Incr(ctx, key).Result()
if err != nil {
return 0, err
}
// Always set TTL (idempotent) to avoid orphan keys if process crashes between INCR and EXPIRE.
if err := c.rdb.Expire(ctx, key, window).Err(); err != nil {
return count, fmt.Errorf("expire notify code rate key: %w", err)
}
return count, nil
}
func (c *emailCache) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
key := notifyCodeUserRateKey(userID)
count, err := c.rdb.Get(ctx, key).Int64()
if err != nil {
return 0, err
}
return count, nil
}

View File

@@ -113,9 +113,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
}
if cmd.BalanceCost > 0 {
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
newBalance, err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost)
if err != nil {
return err
}
result.NewBalance = &newBalance
}
if cmd.APIKeyQuotaCost > 0 {
@@ -133,9 +135,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
}
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
quotaState, err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost)
if err != nil {
return err
}
result.QuotaState = quotaState
}
return nil
@@ -169,24 +173,22 @@ func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscrip
return service.ErrSubscriptionNotFound
}
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
res, err := tx.ExecContext(ctx, `
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) (float64, error) {
var newBalance float64
err := tx.QueryRowContext(ctx, `
UPDATE users
SET balance = balance - $1,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`, amount, userID)
RETURNING balance
`, amount, userID).Scan(&newBalance)
if errors.Is(err, sql.ErrNoRows) {
return 0, service.ErrUserNotFound
}
if err != nil {
return err
return 0, err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil
}
return service.ErrUserNotFound
return newBalance, nil
}
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
@@ -240,7 +242,7 @@ func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKe
return nil
}
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) (*service.AccountQuotaState, error) {
rows, err := tx.QueryContext(ctx,
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
@@ -248,61 +250,71 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
CASE WHEN `+dailyExpiredExpr+`
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
CASE WHEN `+dailyExpiredExpr+`
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
)
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
CASE WHEN `+weeklyExpiredExpr+`
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
CASE WHEN `+weeklyExpiredExpr+`
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
)
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING
COALESCE((extra->>'quota_used')::numeric, 0),
COALESCE((extra->>'quota_limit')::numeric, 0)`,
COALESCE((extra->>'quota_limit')::numeric, 0),
COALESCE((extra->>'quota_daily_used')::numeric, 0),
COALESCE((extra->>'quota_daily_limit')::numeric, 0),
COALESCE((extra->>'quota_weekly_used')::numeric, 0),
COALESCE((extra->>'quota_weekly_limit')::numeric, 0)`,
amount, accountID)
if err != nil {
return err
return nil, err
}
defer func() { _ = rows.Close() }()
var newUsed, limit float64
var state service.AccountQuotaState
if rows.Next() {
if err := rows.Scan(&newUsed, &limit); err != nil {
return err
if err := rows.Scan(
&state.TotalUsed, &state.TotalLimit,
&state.DailyUsed, &state.DailyLimit,
&state.WeeklyUsed, &state.WeeklyLimit,
); err != nil {
return nil, err
}
} else {
if err := rows.Err(); err != nil {
return err
return nil, err
}
return service.ErrAccountNotFound
return nil, service.ErrAccountNotFound
}
if err := rows.Err(); err != nil {
return err
return nil, err
}
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
if state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit {
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
return err
return nil, err
}
}
return nil
return &state, nil
}

View File

@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
@@ -82,6 +82,7 @@ var usageLogInsertArgTypes = [...]string{
"text", // model_mapping_chain
"text", // billing_tier
"text", // billing_mode
"numeric", // account_stats_cost
"timestamptz", // created_at
}
@@ -360,6 +361,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
@@ -367,7 +369,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -797,6 +799,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
) AS (VALUES `)
@@ -873,6 +876,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
)
SELECT
@@ -920,6 +924,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
@@ -1007,10 +1012,11 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*45)
args := make([]any, 0, len(preparedList)*46)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@@ -1080,6 +1086,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
)
SELECT
@@ -1127,6 +1134,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
@@ -1182,6 +1190,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
model_mapping_chain,
billing_tier,
billing_mode,
account_stats_cost,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
@@ -1189,7 +1198,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@@ -1285,6 +1294,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
modelMappingChain,
billingTier,
billingMode,
log.AccountStatsCost, // account_stats_cost
createdAt,
},
}
@@ -1518,6 +1528,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(account_cost), 0) as total_account_cost,
COALESCE(SUM(total_duration_ms), 0) as total_duration_ms
FROM usage_dashboard_daily
`
@@ -1534,6 +1545,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.TotalAccountCost,
&totalDurationMs,
); err != nil {
return err
@@ -1552,6 +1564,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
cache_read_tokens as today_cache_read_tokens,
total_cost as today_cost,
actual_cost as today_actual_cost,
account_cost as today_account_cost,
active_users as active_users
FROM usage_dashboard_daily
WHERE bucket_date = $1::date
@@ -1568,6 +1581,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
&stats.TodayAccountCost,
&stats.ActiveUsers,
); err != nil {
if err != sql.ErrNoRows {
@@ -1603,6 +1617,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
cache_read_tokens,
total_cost,
actual_cost,
COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1) AS account_cost,
COALESCE(duration_ms, 0) AS duration_ms
FROM usage_logs
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
@@ -1616,6 +1631,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens,
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost,
COALESCE(SUM(account_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_account_cost,
COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms,
COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests,
COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens,
@@ -1623,7 +1639,8 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens,
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost,
COALESCE(SUM(account_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_account_cost
FROM scoped
`
var totalDurationMs int64
@@ -1639,6 +1656,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.TotalAccountCost,
&totalDurationMs,
&stats.TodayRequests,
&stats.TodayInputTokens,
@@ -1647,6 +1665,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
&stats.TodayAccountCost,
); err != nil {
return err
}
@@ -1959,7 +1978,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -1989,7 +2008,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -2026,7 +2045,7 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
account_id,
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -2585,7 +2604,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
COALESCE(SUM(actual_cost), 0) as actual_cost,
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as account_cost
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY model
@@ -2990,8 +3010,9 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
accountCostExpr := "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as account_cost"
modelExpr := resolveModelDimensionExpression(source)
query := fmt.Sprintf(`
@@ -3004,10 +3025,11 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s,
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`, modelExpr, actualCostExpr)
`, modelExpr, actualCostExpr, accountCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
@@ -3062,7 +3084,8 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
COUNT(*) as requests,
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(ul.total_cost), 0) as cost,
COALESCE(SUM(ul.actual_cost), 0) as actual_cost
COALESCE(SUM(ul.actual_cost), 0) as actual_cost,
COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost
FROM usage_logs ul
LEFT JOIN groups g ON g.id = ul.group_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
@@ -3113,6 +3136,7 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
&row.AccountCost,
); err != nil {
return nil, err
}
@@ -3133,7 +3157,8 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
COUNT(*) as requests,
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(ul.total_cost), 0) as cost,
COALESCE(SUM(ul.actual_cost), 0) as actual_cost
COALESCE(SUM(ul.actual_cost), 0) as actual_cost,
COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost
FROM usage_logs ul
LEFT JOIN users u ON u.id = ul.user_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
@@ -3204,6 +3229,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
&row.AccountCost,
); err != nil {
return nil, err
}
@@ -3358,7 +3384,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
@@ -3382,9 +3408,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
); err != nil {
return nil, err
}
if filters.AccountID > 0 {
stats.TotalAccountCost = &totalAccountCost
}
stats.TotalAccountCost = &totalAccountCost
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
start := time.Unix(0, 0).UTC()
@@ -3433,7 +3457,7 @@ type EndpointStat = usagestats.EndpointStat
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
@@ -3500,7 +3524,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
@@ -3591,7 +3615,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
@@ -4069,6 +4093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
modelMappingChain sql.NullString
billingTier sql.NullString
billingMode sql.NullString
accountStatsCost sql.NullFloat64
createdAt time.Time
)
@@ -4118,6 +4143,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&modelMappingChain,
&billingTier,
&billingMode,
&accountStatsCost,
&createdAt,
); err != nil {
return nil, err
@@ -4214,6 +4240,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if billingMode.Valid {
log.BillingMode = &billingMode.String
}
if accountStatsCost.Valid {
log.AccountStatsCost = &accountStatsCost.Float64
}
return log, nil
}
@@ -4257,6 +4286,7 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
&row.AccountCost,
); err != nil {
return nil, err
}

View File

@@ -753,8 +753,11 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
s.Require().Equal(baseStats.TotalTokens+int64(51), stats.TotalTokens, "TotalTokens mismatch")
s.Require().Equal(baseStats.TotalCost+2.3, stats.TotalCost, "TotalCost mismatch")
s.Require().Equal(baseStats.TotalActualCost+2.0, stats.TotalActualCost, "TotalActualCost mismatch")
// account_cost falls back to total_cost when account_stats_cost is NULL
s.Require().Equal(baseStats.TotalAccountCost+2.3, stats.TotalAccountCost, "TotalAccountCost mismatch")
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
s.Require().GreaterOrEqual(stats.TodayAccountCost, 0.0, "expected TodayAccountCost >= 0")
wantRpm, wantTpm, err := s.repo.getPerformanceStats(s.ctx, 0)
s.Require().NoError(err, "getPerformanceStats")
@@ -833,6 +836,8 @@ func (s *UsageLogRepoSuite) TestDashboardStatsWithRange_Fallback() {
s.Require().Equal(int64(45), stats.TotalTokens)
s.Require().Equal(1.5, stats.TotalCost)
s.Require().Equal(1.4, stats.TotalActualCost)
// account_cost = COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1) = total_cost
s.Require().Equal(1.5, stats.TotalAccountCost)
s.Require().InEpsilon(150.0, stats.AverageDurationMs, 0.0001)
}

View File

@@ -85,6 +85,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
sqlmock.AnyArg(), // account_stats_cost
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
@@ -163,6 +164,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
sqlmock.AnyArg(), // account_stats_cost
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
@@ -299,7 +301,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
WithArgs(start, end, requestType).
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost", "account_cost"}))
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
require.NoError(t, err)
@@ -344,6 +346,93 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
require.NoError(t, err)
require.Equal(t, int64(1), stats.TotalRequests)
require.Equal(t, int64(9), stats.TotalTokens)
require.NotNil(t, stats.TotalAccountCost, "TotalAccountCost should always be returned")
require.Equal(t, 1.2, *stats.TotalAccountCost)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryGetModelStatsAccountCostColumn(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
mock.ExpectQuery("FROM usage_logs").
WithArgs(start, end).
WillReturnRows(sqlmock.NewRows([]string{
"model", "requests", "input_tokens", "output_tokens",
"cache_creation_tokens", "cache_read_tokens", "total_tokens",
"cost", "actual_cost", "account_cost",
}).
AddRow("claude-opus-4-6", int64(10), int64(100), int64(200), int64(5), int64(3), int64(308), 2.5, 2.0, 1.8).
AddRow("claude-sonnet-4-6", int64(5), int64(50), int64(100), int64(0), int64(0), int64(150), 1.0, 0.8, 0.7))
results, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, nil, nil, nil)
require.NoError(t, err)
require.Len(t, results, 2)
require.Equal(t, "claude-opus-4-6", results[0].Model)
require.Equal(t, 2.5, results[0].Cost)
require.Equal(t, 2.0, results[0].ActualCost)
require.Equal(t, 1.8, results[0].AccountCost)
require.Equal(t, "claude-sonnet-4-6", results[1].Model)
require.Equal(t, 0.7, results[1].AccountCost)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryGetGroupStatsAccountCostColumn(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
mock.ExpectQuery("FROM usage_logs").
WithArgs(start, end).
WillReturnRows(sqlmock.NewRows([]string{
"group_id", "group_name", "requests", "total_tokens",
"cost", "actual_cost", "account_cost",
}).
AddRow(int64(1), "azure-cc", int64(100), int64(5000), 10.0, 8.5, 7.2).
AddRow(int64(2), "max", int64(50), int64(2000), 5.0, 4.0, 3.5))
results, err := repo.GetGroupStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, nil, nil, nil)
require.NoError(t, err)
require.Len(t, results, 2)
require.Equal(t, int64(1), results[0].GroupID)
require.Equal(t, "azure-cc", results[0].GroupName)
require.Equal(t, 10.0, results[0].Cost)
require.Equal(t, 8.5, results[0].ActualCost)
require.Equal(t, 7.2, results[0].AccountCost)
require.Equal(t, int64(2), results[1].GroupID)
require.Equal(t, 3.5, results[1].AccountCost)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryGetStatsWithFiltersAlwaysReturnsAccountCost(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
// No AccountID filter set - TotalAccountCost should still be returned
filters := usagestats.UsageLogFilters{}
mock.ExpectQuery("FROM usage_logs").
WillReturnRows(sqlmock.NewRows([]string{
"total_requests", "total_input_tokens", "total_output_tokens",
"total_cache_tokens", "total_cost", "total_actual_cost",
"total_account_cost", "avg_duration_ms",
}).AddRow(int64(50), int64(1000), int64(2000), int64(100), 15.0, 12.5, 11.0, 100.0))
mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(inbound_endpoint\\)").
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(upstream_endpoint\\)").
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
mock.ExpectQuery("SELECT CONCAT\\(").
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
stats, err := repo.GetStatsWithFilters(context.Background(), filters)
require.NoError(t, err)
require.NotNil(t, stats.TotalAccountCost, "TotalAccountCost must always be returned, even without AccountID filter")
require.Equal(t, 11.0, *stats.TotalAccountCost)
require.NoError(t, mock.ExpectationsWereMet())
}
@@ -483,10 +572,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)
@@ -530,10 +620,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)
@@ -577,10 +668,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)

View File

@@ -100,7 +100,7 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
query := `
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
WHERE ugr.group_id = $1
ORDER BY ugr.user_id
`

View File

@@ -137,7 +137,7 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
txClient = r.client
}
updated, err := txClient.User.UpdateOneID(userIn.ID).
updateOp := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetNotes(userIn.Notes).
@@ -146,7 +146,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
Save(ctx)
SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled).
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged)
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}
updated, err := updateOp.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
}
@@ -382,7 +390,12 @@ func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
update := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount)
// Track cumulative recharge amount for percentage-based notifications
if amount > 0 {
update = update.AddTotalRecharged(amount)
}
n, err := update.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
@@ -549,6 +562,11 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst.UpdatedAt = src.UpdatedAt
}
// marshalExtraEmails serializes notify email entries to JSON for storage.
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
return service.MarshalNotifyEmails(entries)
}
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
client := clientFromContext(ctx, r.client)

View File

@@ -58,6 +58,11 @@ func TestAPIContracts(t *testing.T) {
"allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z",
"balance_notify_enabled": false,
"balance_notify_threshold_type": "",
"balance_notify_threshold": null,
"balance_notify_extra_emails": null,
"total_recharged": 0,
"run_mode": "standard"
}
}`,
@@ -204,11 +209,10 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"claude_code_only": false,
"claude_code_only": false,
"allow_messages_dispatch": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
"allow_messages_dispatch": false,
"require_oauth_only": false,
"require_privacy_set": false,
"created_at": "2025-01-02T03:04:05Z",
@@ -587,26 +591,34 @@ func TestAPIContracts(t *testing.T) {
"enable_cch_signing": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"web_search_emulation_enabled": false,
"custom_menu_items": [],
"custom_endpoints": [],
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
"payment_daily_limit": 0,
"payment_order_timeout_minutes": 0,
"payment_max_pending_orders": 0,
"payment_enabled_types": null,
"payment_balance_disabled": false,
"payment_balance_recharge_multiplier": 0,
"payment_recharge_fee_rate": 0,
"payment_load_balance_strategy": "",
"payment_product_name_prefix": "",
"payment_product_name_suffix": "",
"payment_help_image_url": "",
"payment_help_text": "",
"payment_enabled_types": null,
"payment_cancel_rate_limit_enabled": false,
"payment_cancel_rate_limit_max": 0,
"payment_cancel_rate_limit_window": 0,
"payment_cancel_rate_limit_unit": "",
"payment_cancel_rate_limit_window_mode": "",
"custom_menu_items": [],
"custom_endpoints": []
"balance_low_notify_enabled": false,
"account_quota_notify_enabled": false,
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
"account_quota_notify_emails": []
}
}`,
},
@@ -699,7 +711,7 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode: config.RunModeStandard,
}
userService := service.NewUserService(userRepo, nil, nil)
userService := service.NewUserService(userRepo, nil, nil, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()

View File

@@ -2,12 +2,15 @@
package server
import (
"context"
"log"
"log/slog"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -56,6 +59,42 @@ func ProvideRouter(
}
}
// Wire up websearch Manager builder so it initializes on startup and rebuilds on config save.
settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig, proxyURLs map[int64]string) {
if cfg == nil || !cfg.Enabled || len(cfg.Providers) == 0 {
service.SetWebSearchManager(nil)
return
}
configs := make([]websearch.ProviderConfig, 0, len(cfg.Providers))
for _, p := range cfg.Providers {
if p.APIKey == "" {
continue
}
pc := websearch.ProviderConfig{
Type: p.Type,
APIKey: p.APIKey,
QuotaLimit: derefInt64(p.QuotaLimit),
ExpiresAt: p.ExpiresAt,
}
if p.SubscribedAt != nil {
pc.SubscribedAt = p.SubscribedAt
}
if p.ProxyID != nil {
pc.ProxyID = *p.ProxyID
if u, ok := proxyURLs[*p.ProxyID]; ok {
pc.ProxyURL = u
} else {
// Proxy configured but not found — skip this provider to prevent direct connection.
slog.Warn("websearch: proxy not found for provider, skipping",
"provider", p.Type, "proxy_id", *p.ProxyID)
continue
}
}
configs = append(configs, pc)
}
service.SetWebSearchManager(websearch.NewManager(configs, redisClient))
})
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
}
@@ -102,3 +141,10 @@ func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
}
func derefInt64(p *int64) int64 {
if p == nil {
return 0
}
return *p
}

View File

@@ -39,7 +39,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
return &clone, nil
},
}
userService := service.NewUserService(userRepo, nil, nil)
userService := service.NewUserService(userRepo, nil, nil, nil)
router := gin.New()
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))

View File

@@ -41,7 +41,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
userRepo := &stubJWTUserRepo{users: users}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)
r := gin.New()

View File

@@ -18,6 +18,8 @@ const (
NonceTemplate = "__CSP_NONCE__"
// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
// StripeDomain is the domain for Stripe.js SDK
StripeDomain = "https://*.stripe.com"
)
// GenerateNonce generates a cryptographically secure random nonce.
@@ -97,8 +99,9 @@ func isAPIRoutePath(c *gin.Context) bool {
strings.HasPrefix(path, "/responses")
}
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
// This allows the application to work correctly even if the config file has an older CSP policy.
// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights,
// and Stripe.js domains. This allows the application to work correctly even if the
// config file has an older CSP policy.
func enhanceCSPPolicy(policy string) string {
// Add nonce placeholder to script-src if not present
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
@@ -110,6 +113,12 @@ func enhanceCSPPolicy(policy string) string {
policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
}
// Add Stripe.js domain to script-src and frame-src if not present
if !strings.Contains(policy, "stripe.com") {
policy = addToDirective(policy, "script-src", StripeDomain)
policy = addToDirective(policy, "frame-src", StripeDomain)
}
return policy
}

View File

@@ -407,6 +407,11 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Beta 策略配置
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
// Web Search 模拟配置
adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig)
adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig)
adminSettings.POST("/web-search-emulation/test", h.Admin.Setting.TestWebSearchEmulation)
adminSettings.POST("/web-search-emulation/reset-usage", h.Admin.Setting.ResetWebSearchUsage)
}
}

View File

@@ -39,6 +39,7 @@ func RegisterPaymentRoutes(
orders.GET("/:id", paymentHandler.GetOrder)
orders.POST("/:id/cancel", paymentHandler.CancelOrder)
orders.POST("/:id/refund-request", paymentHandler.RequestRefund)
orders.GET("/refund-eligible-providers", paymentHandler.GetRefundEligibleProviders)
}
}

View File

@@ -26,6 +26,15 @@ func RegisterUserRoutes(
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
// 通知邮箱管理
notifyEmail := user.Group("/notify-email")
{
notifyEmail.POST("/send-code", h.User.SendNotifyEmailCode)
notifyEmail.POST("/verify", h.User.VerifyNotifyEmail)
notifyEmail.PUT("/toggle", h.User.ToggleNotifyEmail)
notifyEmail.DELETE("", h.User.RemoveNotifyEmail)
}
// TOTP 双因素认证
totp := user.Group("/totp")
{

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"hash/fnv"
"log/slog"
"reflect"
"sort"
"strconv"
@@ -969,7 +970,7 @@ func (a *Account) IsOveragesEnabled() bool {
return false
}
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用自动透传(仅替换认证)
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用"自动透传(仅替换认证)"
//
// 新字段accounts.extra.openai_passthrough。
// 兼容字段accounts.extra.openai_oauth_passthrough历史 OAuth 开关)。
@@ -1133,7 +1134,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return resolvedDefault
}
// IsOpenAIWSForceHTTPEnabled 返回账号级强制 HTTP开关。
// IsOpenAIWSForceHTTPEnabled 返回账号级"强制 HTTP"开关。
// 字段accounts.extra.openai_ws_force_http。
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
@@ -1158,7 +1159,7 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
}
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用自动透传(仅替换认证)
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用"自动透传(仅替换认证)"
// 字段accounts.extra.anthropic_passthrough。
// 字段缺失或类型不正确时,按 false关闭处理。
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
@@ -1169,7 +1170,42 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
return ok && enabled
}
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。
// WebSearch 模拟三态常量
const (
WebSearchModeDefault = "default" // 跟随渠道配置
WebSearchModeEnabled = "enabled" // 强制开启
WebSearchModeDisabled = "disabled" // 强制关闭
)
// GetWebSearchEmulationMode 返回账号的 WebSearch 模拟模式。
// 三态default跟随渠道/ enabled强制开启/ disabled强制关闭
// 兼容旧 bool 值true→enabled, false→default并记录 debug 日志)。
func (a *Account) GetWebSearchEmulationMode() string {
if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
return WebSearchModeDefault
}
raw := a.Extra[featureKeyWebSearchEmulation]
// Tolerant: legacy bool values (pre-migration or stale writes)
if b, ok := raw.(bool); ok {
slog.Debug("legacy bool web_search_emulation value", "account_id", a.ID, "value", b)
if b {
return WebSearchModeEnabled
}
return WebSearchModeDefault
}
mode, ok := raw.(string)
if !ok {
return WebSearchModeDefault
}
switch mode {
case WebSearchModeEnabled, WebSearchModeDisabled:
return mode
default:
return WebSearchModeDefault
}
}
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
// 字段accounts.extra.codex_cli_only。
// 字段缺失或类型不正确时,按 false关闭处理。
func (a *Account) IsCodexCLIOnlyEnabled() bool {
@@ -1395,6 +1431,19 @@ func (a *Account) getExtraTime(key string) time.Time {
return time.Time{}
}
// getExtraBool 从 Extra 中读取指定 key 的 bool 值
func (a *Account) getExtraBool(key string) bool {
if a.Extra == nil {
return false
}
if v, ok := a.Extra[key]; ok {
if b, ok := v.(bool); ok {
return b
}
}
return false
}
// getExtraString 从 Extra 中读取指定 key 的字符串值
func (a *Account) getExtraString(key string) string {
if a.Extra == nil {
@@ -1408,6 +1457,14 @@ func (a *Account) getExtraString(key string) string {
return ""
}
// getExtraStringDefault 从 Extra 中读取指定 key 的字符串值,不存在时返回 defaultVal
func (a *Account) getExtraStringDefault(key, defaultVal string) string {
if v := a.getExtraString(key); v != "" {
return v
}
return defaultVal
}
// getExtraInt 从 Extra 中读取指定 key 的 int 值
func (a *Account) getExtraInt(key string) int {
if a.Extra == nil {
@@ -1464,6 +1521,62 @@ func (a *Account) GetQuotaResetTimezone() string {
return "UTC"
}
// --- Quota Notification Getters ---
// QuotaNotifyConfig returns the notify configuration for a given quota dimension.
// dim must be one of quotaDimDaily, quotaDimWeekly, quotaDimTotal.
func (a *Account) QuotaNotifyConfig(dim string) (enabled bool, threshold float64, thresholdType string) {
enabled = a.getExtraBool("quota_notify_" + dim + "_enabled")
threshold = a.getExtraFloat64("quota_notify_" + dim + "_threshold")
thresholdType = a.getExtraStringDefault("quota_notify_"+dim+"_threshold_type", thresholdTypeFixed)
return
}
func (a *Account) GetQuotaNotifyDailyEnabled() bool {
e, _, _ := a.QuotaNotifyConfig(quotaDimDaily)
return e
}
func (a *Account) GetQuotaNotifyDailyThreshold() float64 {
_, t, _ := a.QuotaNotifyConfig(quotaDimDaily)
return t
}
func (a *Account) GetQuotaNotifyDailyThresholdType() string {
_, _, tt := a.QuotaNotifyConfig(quotaDimDaily)
return tt
}
func (a *Account) GetQuotaNotifyWeeklyEnabled() bool {
e, _, _ := a.QuotaNotifyConfig(quotaDimWeekly)
return e
}
func (a *Account) GetQuotaNotifyWeeklyThreshold() float64 {
_, t, _ := a.QuotaNotifyConfig(quotaDimWeekly)
return t
}
func (a *Account) GetQuotaNotifyWeeklyThresholdType() string {
_, _, tt := a.QuotaNotifyConfig(quotaDimWeekly)
return tt
}
func (a *Account) GetQuotaNotifyTotalEnabled() bool {
e, _, _ := a.QuotaNotifyConfig(quotaDimTotal)
return e
}
func (a *Account) GetQuotaNotifyTotalThreshold() float64 {
_, t, _ := a.QuotaNotifyConfig(quotaDimTotal)
return t
}
func (a *Account) GetQuotaNotifyTotalThresholdType() string {
_, _, tt := a.QuotaNotifyConfig(quotaDimTotal)
return tt
}
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
t := after.In(tz)

View File

@@ -0,0 +1,236 @@
package service
import (
"context"
"strings"
)
// resolveAccountStatsCost 计算账号统计定价费用。
// 返回 nil 表示不覆盖使用默认公式total_cost × account_rate_multiplier
//
// 优先级(先命中为准):
// 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关)
// 2. ApplyPricingToAccountStats 启用时,直接使用本次请求的客户计费(倍率前的 totalCost
// 3. 模型定价文件LiteLLM中上游模型的默认价格
// 4. nil → 走默认公式total_cost × account_rate_multiplier
//
// upstreamModel 是最终发往上游的模型 ID。
// totalCost 是本次请求的客户计费(倍率前),用于优先级 2。
func resolveAccountStatsCost(
ctx context.Context,
channelService *ChannelService,
billingService *BillingService,
accountID int64,
groupID int64,
upstreamModel string,
tokens UsageTokens,
requestCount int,
totalCost float64,
) *float64 {
if channelService == nil || upstreamModel == "" {
return nil
}
channel, err := channelService.GetChannelForGroup(ctx, groupID)
if err != nil || channel == nil {
return nil
}
platform := channelService.GetGroupPlatform(ctx, groupID)
// 优先级 1自定义规则始终尝试
if cost := tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount); cost != nil {
return cost
}
// 优先级 2渠道开启"应用模型定价到账号统计"时,直接使用客户计费(倍率前)
if channel.ApplyPricingToAccountStats {
cost := totalCost
if cost <= 0 {
return nil
}
return &cost
}
// 优先级 3模型定价文件LiteLLM默认价格
if billingService != nil {
return tryModelFilePricing(billingService, upstreamModel, tokens)
}
return nil
}
// tryModelFilePricing 使用模型定价文件LiteLLM/fallback中的标准价格计算费用。
func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 {
pricing, err := billingService.GetModelPricing(model)
if err != nil || pricing == nil {
return nil
}
cost := float64(tokens.InputTokens)*pricing.InputPricePerToken +
float64(tokens.OutputTokens)*pricing.OutputPricePerToken +
float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken +
float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken +
float64(tokens.ImageOutputTokens)*pricing.ImageOutputPricePerToken
if cost <= 0 {
return nil
}
return &cost
}
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
func tryCustomRules(
channel *Channel, accountID, groupID int64,
platform, model string, tokens UsageTokens, requestCount int,
) *float64 {
modelLower := strings.ToLower(model)
for _, rule := range channel.AccountStatsPricingRules {
if !matchAccountStatsRule(&rule, accountID, groupID) {
continue
}
pricing := findPricingForModel(rule.Pricing, platform, modelLower)
if pricing == nil {
continue // 规则匹配但模型不在规则定价中,继续下一条
}
return calculateStatsCost(pricing, tokens, requestCount)
}
return nil
}
// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
// 匹配条件accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 {
return false
}
for _, id := range rule.AccountIDs {
if id == accountID {
return true
}
}
for _, id := range rule.GroupIDs {
if id == groupID {
return true
}
}
return false
}
// findPricingForModel 在定价列表中查找匹配的模型定价。
// 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
// 精确匹配优先
for i := range pricingList {
p := &pricingList[i]
if !isPlatformMatch(platform, p.Platform) {
continue
}
for _, m := range p.Models {
if strings.ToLower(m) == modelLower {
return p
}
}
}
// 通配符匹配:按配置顺序,先匹配先使用
for i := range pricingList {
p := &pricingList[i]
if !isPlatformMatch(platform, p.Platform) {
continue
}
for _, m := range p.Models {
ml := strings.ToLower(m)
if !strings.HasSuffix(ml, "*") {
continue
}
prefix := strings.TrimSuffix(ml, "*")
if strings.HasPrefix(modelLower, prefix) {
return p
}
}
}
return nil
}
// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。
func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
if queryPlatform == "" || pricingPlatform == "" {
return true
}
return queryPlatform == pricingPlatform
}
// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。
func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
if pricing == nil {
return nil
}
switch pricing.BillingMode {
case BillingModePerRequest, BillingModeImage:
return calculatePerRequestStatsCost(pricing, requestCount)
default:
return calculateTokenStatsCost(pricing, tokens)
}
}
// calculatePerRequestStatsCost 按次/图片计费。
func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 {
return nil
}
cost := *pricing.PerRequestPrice * float64(requestCount)
return &cost
}
// calculateTokenStatsCost Token 计费。
// If the pricing has intervals, find the matching interval by total token count
// and use its prices instead of the flat pricing fields.
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
p := pricing
if len(pricing.Intervals) > 0 {
totalTokens := tokens.InputTokens + tokens.OutputTokens + tokens.CacheCreationTokens + tokens.CacheReadTokens
if iv := FindMatchingInterval(pricing.Intervals, totalTokens); iv != nil {
p = &ChannelModelPricing{
InputPrice: iv.InputPrice,
OutputPrice: iv.OutputPrice,
CacheWritePrice: iv.CacheWritePrice,
CacheReadPrice: iv.CacheReadPrice,
PerRequestPrice: iv.PerRequestPrice,
}
}
}
deref := func(ptr *float64) float64 {
if ptr == nil {
return 0
}
return *ptr
}
cost := float64(tokens.InputTokens)*deref(p.InputPrice) +
float64(tokens.OutputTokens)*deref(p.OutputPrice) +
float64(tokens.CacheCreationTokens)*deref(p.CacheWritePrice) +
float64(tokens.CacheReadTokens)*deref(p.CacheReadPrice) +
float64(tokens.ImageOutputTokens)*deref(p.ImageOutputPrice)
if cost <= 0 {
return nil
}
return &cost
}
// applyAccountStatsCost resolves the account stats cost for a usage log entry.
// It resolves the upstream model (falling back to the requested model) and calls
// the 4-level priority chain via resolveAccountStatsCost.
func applyAccountStatsCost(
ctx context.Context,
usageLog *UsageLog,
cs *ChannelService, bs *BillingService,
accountID int64, groupID int64,
upstreamModel, requestedModel string,
tokens UsageTokens,
totalCost float64,
) {
model := upstreamModel
if model == "" {
model = requestedModel
}
usageLog.AccountStatsCost = resolveAccountStatsCost(
ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost,
)
}

View File

@@ -0,0 +1,771 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// matchAccountStatsRule
// ---------------------------------------------------------------------------
func TestMatchAccountStatsRule_BothEmpty_NoMatch(t *testing.T) {
rule := &AccountStatsPricingRule{}
require.False(t, matchAccountStatsRule(rule, 1, 10))
}
func TestMatchAccountStatsRule_AccountIDMatch(t *testing.T) {
rule := &AccountStatsPricingRule{AccountIDs: []int64{1, 2, 3}}
require.True(t, matchAccountStatsRule(rule, 2, 999))
}
func TestMatchAccountStatsRule_GroupIDMatch(t *testing.T) {
rule := &AccountStatsPricingRule{GroupIDs: []int64{10, 20}}
require.True(t, matchAccountStatsRule(rule, 999, 20))
}
func TestMatchAccountStatsRule_BothConfigured_AccountMatch(t *testing.T) {
rule := &AccountStatsPricingRule{
AccountIDs: []int64{1, 2},
GroupIDs: []int64{10, 20},
}
require.True(t, matchAccountStatsRule(rule, 2, 999))
}
func TestMatchAccountStatsRule_BothConfigured_GroupMatch(t *testing.T) {
rule := &AccountStatsPricingRule{
AccountIDs: []int64{1, 2},
GroupIDs: []int64{10, 20},
}
require.True(t, matchAccountStatsRule(rule, 999, 10))
}
func TestMatchAccountStatsRule_BothConfigured_NeitherMatch(t *testing.T) {
rule := &AccountStatsPricingRule{
AccountIDs: []int64{1, 2},
GroupIDs: []int64{10, 20},
}
require.False(t, matchAccountStatsRule(rule, 999, 999))
}
// ---------------------------------------------------------------------------
// findPricingForModel
// ---------------------------------------------------------------------------
func TestFindPricingForModel(t *testing.T) {
exactPricing := ChannelModelPricing{
ID: 1,
Models: []string{"claude-opus-4"},
}
wildcardPricing := ChannelModelPricing{
ID: 2,
Models: []string{"claude-*"},
}
platformPricing := ChannelModelPricing{
ID: 3,
Platform: "openai",
Models: []string{"gpt-4o"},
}
emptyPlatformPricing := ChannelModelPricing{
ID: 4,
Models: []string{"gemini-2.5-pro"},
}
tests := []struct {
name string
list []ChannelModelPricing
platform string
model string
wantID int64
wantNil bool
}{
{
name: "exact match",
list: []ChannelModelPricing{exactPricing},
platform: "anthropic",
model: "claude-opus-4",
wantID: 1,
},
{
name: "exact match case insensitive",
list: []ChannelModelPricing{{ID: 5, Models: []string{"Claude-Opus-4"}}},
platform: "",
model: "claude-opus-4",
wantID: 5,
},
{
name: "wildcard match",
list: []ChannelModelPricing{wildcardPricing},
platform: "anthropic",
model: "claude-opus-4",
wantID: 2,
},
{
name: "exact match takes priority over wildcard",
list: []ChannelModelPricing{wildcardPricing, exactPricing},
platform: "anthropic",
model: "claude-opus-4",
wantID: 1,
},
{
name: "platform mismatch skipped",
list: []ChannelModelPricing{platformPricing},
platform: "anthropic",
model: "gpt-4o",
wantNil: true,
},
{
name: "empty platform in pricing matches any",
list: []ChannelModelPricing{emptyPlatformPricing},
platform: "gemini",
model: "gemini-2.5-pro",
wantID: 4,
},
{
name: "empty platform in query matches any pricing platform",
list: []ChannelModelPricing{platformPricing},
platform: "",
model: "gpt-4o",
wantID: 3,
},
{
name: "no match at all",
list: []ChannelModelPricing{exactPricing, wildcardPricing},
platform: "anthropic",
model: "gpt-4o",
wantNil: true,
},
{
name: "empty list returns nil",
list: nil,
model: "claude-opus-4",
wantNil: true,
},
{
name: "wildcard matches by config order (first match wins)",
list: []ChannelModelPricing{
{ID: 10, Models: []string{"claude-*"}},
{ID: 11, Models: []string{"claude-opus-*"}},
},
platform: "",
model: "claude-opus-4",
wantID: 10, // config order: "claude-*" is first and matches, so it wins
},
{
name: "shorter wildcard used when longer does not match",
list: []ChannelModelPricing{
{ID: 10, Models: []string{"claude-*"}},
{ID: 11, Models: []string{"claude-opus-*"}},
},
platform: "",
model: "claude-sonnet-4",
wantID: 10, // only "claude-*" matches
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := findPricingForModel(tt.list, tt.platform, tt.model)
if tt.wantNil {
require.Nil(t, result)
return
}
require.NotNil(t, result)
require.Equal(t, tt.wantID, result.ID)
})
}
}
// ---------------------------------------------------------------------------
// calculateStatsCost
// ---------------------------------------------------------------------------
func TestCalculateStatsCost_NilPricing(t *testing.T) {
result := calculateStatsCost(nil, UsageTokens{}, 1)
require.Nil(t, result)
}
func TestCalculateStatsCost_TokenBilling(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(0.001),
OutputPrice: testPtrFloat64(0.002),
}
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
}
result := calculateStatsCost(pricing, tokens, 1)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
require.InDelta(t, 0.2, *result, 1e-12)
}
func TestCalculateStatsCost_TokenBilling_WithCache(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(0.001),
OutputPrice: testPtrFloat64(0.002),
CacheWritePrice: testPtrFloat64(0.003),
CacheReadPrice: testPtrFloat64(0.0005),
}
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
CacheCreationTokens: 200,
CacheReadTokens: 300,
}
result := calculateStatsCost(pricing, tokens, 1)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
require.InDelta(t, 0.95, *result, 1e-12)
}
func TestCalculateStatsCost_TokenBilling_WithImageOutput(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(0.001),
OutputPrice: testPtrFloat64(0.002),
ImageOutputPrice: testPtrFloat64(0.01),
}
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
ImageOutputTokens: 10,
}
result := calculateStatsCost(pricing, tokens, 1)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
require.InDelta(t, 0.3, *result, 1e-12)
}
func TestCalculateStatsCost_TokenBilling_PartialPricesNil(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(0.001),
// OutputPrice, CacheWritePrice, etc. are all nil → treated as 0
}
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
CacheCreationTokens: 200,
}
result := calculateStatsCost(pricing, tokens, 1)
require.NotNil(t, result)
// Only input contributes: 100*0.001 = 0.1
require.InDelta(t, 0.1, *result, 1e-12)
}
func TestCalculateStatsCost_TokenBilling_AllTokensZero(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(0.001),
OutputPrice: testPtrFloat64(0.002),
}
tokens := UsageTokens{} // all zeros
result := calculateStatsCost(pricing, tokens, 1)
// totalCost == 0 → returns nil (does not override, falls back to default formula)
require.Nil(t, result)
}
func TestCalculateStatsCost_PerRequestBilling(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.05),
}
tokens := UsageTokens{InputTokens: 999, OutputTokens: 999}
result := calculateStatsCost(pricing, tokens, 3)
require.NotNil(t, result)
// 0.05 * 3 = 0.15
require.InDelta(t, 0.15, *result, 1e-12)
}
func TestCalculateStatsCost_PerRequestBilling_PriceNil(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModePerRequest,
// PerRequestPrice is nil
}
result := calculateStatsCost(pricing, UsageTokens{}, 1)
require.Nil(t, result)
}
func TestCalculateStatsCost_PerRequestBilling_PriceZero(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0),
}
result := calculateStatsCost(pricing, UsageTokens{}, 1)
// price == 0 → condition *pricing.PerRequestPrice > 0 is false → returns nil
require.Nil(t, result)
}
func TestCalculateStatsCost_ImageBilling(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeImage,
PerRequestPrice: testPtrFloat64(0.10),
}
result := calculateStatsCost(pricing, UsageTokens{}, 2)
require.NotNil(t, result)
// 0.10 * 2 = 0.20
require.InDelta(t, 0.20, *result, 1e-12)
}
func TestCalculateStatsCost_ImageBilling_PriceNil(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeImage,
// PerRequestPrice is nil
}
result := calculateStatsCost(pricing, UsageTokens{}, 1)
require.Nil(t, result)
}
func TestCalculateStatsCost_DefaultBillingMode_FallsToToken(t *testing.T) {
// BillingMode is empty string (default) → falls into token billing
pricing := &ChannelModelPricing{
InputPrice: testPtrFloat64(0.001),
OutputPrice: testPtrFloat64(0.002),
}
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
}
result := calculateStatsCost(pricing, tokens, 1)
require.NotNil(t, result)
require.InDelta(t, 0.2, *result, 1e-12)
}
// ---------------------------------------------------------------------------
// tryCustomRules — 多规则顺序测试
// ---------------------------------------------------------------------------
func TestTryCustomRules_FirstMatchWins(t *testing.T) {
channel := &Channel{
AccountStatsPricingRules: []AccountStatsPricingRule{
{
GroupIDs: []int64{1},
Pricing: []ChannelModelPricing{
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01), OutputPrice: testPtrFloat64(0.02)},
},
},
{
GroupIDs: []int64{1},
Pricing: []ChannelModelPricing{
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99), OutputPrice: testPtrFloat64(0.99)},
},
},
},
}
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
require.NotNil(t, result)
// 应使用第一条规则的价格100*0.01 + 50*0.02 = 2.0
require.InDelta(t, 2.0, *result, 1e-12)
}
func TestTryCustomRules_SkipsNonMatchingRules(t *testing.T) {
channel := &Channel{
AccountStatsPricingRules: []AccountStatsPricingRule{
{
AccountIDs: []int64{888}, // 不匹配
Pricing: []ChannelModelPricing{
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99)},
},
},
{
GroupIDs: []int64{1}, // 匹配
Pricing: []ChannelModelPricing{
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)},
},
},
},
}
tokens := UsageTokens{InputTokens: 100}
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
require.NotNil(t, result)
// 跳过规则1账号不匹配使用规则2100*0.05 = 5.0
require.InDelta(t, 5.0, *result, 1e-12)
}
func TestTryCustomRules_NoMatch_ReturnsNil(t *testing.T) {
channel := &Channel{
AccountStatsPricingRules: []AccountStatsPricingRule{
{
AccountIDs: []int64{888},
Pricing: []ChannelModelPricing{
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01)},
},
},
},
}
tokens := UsageTokens{InputTokens: 100}
result := tryCustomRules(channel, 999, 2, "", "claude-opus-4", tokens, 1)
require.Nil(t, result) // 账号和分组都不匹配
}
func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) {
channel := &Channel{
AccountStatsPricingRules: []AccountStatsPricingRule{
{
GroupIDs: []int64{1},
Pricing: []ChannelModelPricing{
{ID: 100, Models: []string{"gpt-4o"}, InputPrice: testPtrFloat64(0.01)}, // 模型不匹配
},
},
{
GroupIDs: []int64{1},
Pricing: []ChannelModelPricing{
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, // 模型匹配
},
},
},
}
tokens := UsageTokens{InputTokens: 100}
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
require.NotNil(t, result)
require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2
}
// ---------------------------------------------------------------------------
// tryModelFilePricing
// ---------------------------------------------------------------------------
// newTestBillingServiceWithPrices creates a BillingService with pre-populated
// fallback prices for testing. No config or pricing service is needed.
// The key must match what getFallbackPricing resolves to for a given model name.
// E.g., model "claude-sonnet-4" resolves to key "claude-sonnet-4".
func newTestBillingServiceWithPrices(prices map[string]*ModelPricing) *BillingService {
return &BillingService{
fallbackPrices: prices,
}
}
func TestTryModelFilePricing_Success(t *testing.T) {
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
},
})
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
require.InDelta(t, 0.2, *result, 1e-12)
}
func TestTryModelFilePricing_PricingNotFound(t *testing.T) {
// "nonexistent-model" does not match any fallback pattern
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{})
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := tryModelFilePricing(bs, "nonexistent-model", tokens)
require.Nil(t, result)
}
func TestTryModelFilePricing_NilFallback(t *testing.T) {
// getFallbackPricing returns nil when key maps to nil
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": nil,
})
tokens := UsageTokens{InputTokens: 100}
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.Nil(t, result)
}
func TestTryModelFilePricing_ZeroCost(t *testing.T) {
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
},
})
tokens := UsageTokens{} // all zero tokens → cost = 0 → nil
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.Nil(t, result)
}
func TestTryModelFilePricing_WithImageOutput(t *testing.T) {
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
ImageOutputPricePerToken: 0.01,
},
})
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
ImageOutputTokens: 10,
}
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
require.InDelta(t, 0.3, *result, 1e-12)
}
func TestTryModelFilePricing_WithCacheTokens(t *testing.T) {
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
CacheCreationPricePerToken: 0.003,
CacheReadPricePerToken: 0.0005,
},
})
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
CacheCreationTokens: 200,
CacheReadTokens: 300,
}
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
require.InDelta(t, 0.95, *result, 1e-12)
}
// ---------------------------------------------------------------------------
// resolveAccountStatsCost — integration tests covering the 4-level priority chain
// ---------------------------------------------------------------------------
func TestResolveAccountStatsCost_NilChannelService(t *testing.T) {
result := resolveAccountStatsCost(
context.Background(),
nil, // channelService is nil
newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
1, 1, "claude-sonnet-4",
UsageTokens{InputTokens: 100}, 1, 0.5,
)
require.Nil(t, result)
}
func TestResolveAccountStatsCost_EmptyUpstreamModel(t *testing.T) {
cs := newTestChannelServiceForStats(t, &Channel{
ID: 1,
Status: StatusActive,
}, 1, "")
result := resolveAccountStatsCost(
context.Background(),
cs,
newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
1, 1, "", // empty upstream model
UsageTokens{InputTokens: 100}, 1, 0.5,
)
require.Nil(t, result)
}
func TestResolveAccountStatsCost_GetChannelForGroupReturnsNil(t *testing.T) {
// Group 99 is NOT in the cache, so GetChannelForGroup returns nil
cs := newTestChannelServiceForStats(t, &Channel{
ID: 1,
Status: StatusActive,
}, 1, "")
result := resolveAccountStatsCost(
context.Background(),
cs,
newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
1, 99, "claude-sonnet-4", // groupID 99 has no channel
UsageTokens{InputTokens: 100}, 1, 0.5,
)
require.Nil(t, result)
}
func TestResolveAccountStatsCost_HitsCustomRule(t *testing.T) {
channel := &Channel{
ID: 1,
Status: StatusActive,
AccountStatsPricingRules: []AccountStatsPricingRule{
{
GroupIDs: []int64{10},
Pricing: []ChannelModelPricing{
{
ID: 100,
Models: []string{"claude-sonnet-4"},
InputPrice: testPtrFloat64(0.01),
OutputPrice: testPtrFloat64(0.02),
},
},
},
},
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := resolveAccountStatsCost(
context.Background(),
cs, nil, // billingService not needed when custom rule hits
1, 10, "claude-sonnet-4",
tokens, 1, 999.0, // totalCost ignored because custom rule hits
)
require.NotNil(t, result)
// 100*0.01 + 50*0.02 = 1.0 + 1.0 = 2.0
require.InDelta(t, 2.0, *result, 1e-12)
}
func TestResolveAccountStatsCost_ApplyPricingToAccountStats_UsesTotalCost(t *testing.T) {
channel := &Channel{
ID: 1,
Status: StatusActive,
ApplyPricingToAccountStats: true,
// No custom rules
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := resolveAccountStatsCost(
context.Background(),
cs, nil,
1, 10, "claude-sonnet-4",
tokens, 1, 0.75, // totalCost = 0.75
)
require.NotNil(t, result)
require.InDelta(t, 0.75, *result, 1e-12)
}
func TestResolveAccountStatsCost_ApplyPricingToAccountStats_ZeroTotalCost_ReturnsNil(t *testing.T) {
channel := &Channel{
ID: 1,
Status: StatusActive,
ApplyPricingToAccountStats: true,
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
result := resolveAccountStatsCost(
context.Background(),
cs, nil,
1, 10, "claude-sonnet-4",
UsageTokens{}, 1, 0.0, // totalCost = 0
)
require.Nil(t, result)
}
func TestResolveAccountStatsCost_FallsBackToLiteLLM(t *testing.T) {
channel := &Channel{
ID: 1,
Status: StatusActive,
ApplyPricingToAccountStats: false, // not enabled
// No custom rules
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
},
})
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := resolveAccountStatsCost(
context.Background(),
cs, bs,
1, 10, "claude-sonnet-4",
tokens, 1, 999.0, // totalCost ignored
)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
require.InDelta(t, 0.2, *result, 1e-12)
}
func TestResolveAccountStatsCost_AllMiss_ReturnsNil(t *testing.T) {
channel := &Channel{
ID: 1,
Status: StatusActive,
ApplyPricingToAccountStats: false,
// No custom rules
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
// BillingService with no pricing for the model
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{})
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := resolveAccountStatsCost(
context.Background(),
cs, bs,
1, 10, "totally-unknown-model",
tokens, 1, 0.0,
)
require.Nil(t, result)
}
func TestResolveAccountStatsCost_NilBillingService_SkipsLiteLLM(t *testing.T) {
channel := &Channel{
ID: 1,
Status: StatusActive,
ApplyPricingToAccountStats: false,
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
result := resolveAccountStatsCost(
context.Background(),
cs, nil, // billingService is nil
1, 10, "claude-sonnet-4",
UsageTokens{InputTokens: 100}, 1, 0.0,
)
require.Nil(t, result)
}
func TestResolveAccountStatsCost_CustomRulePriorityOverApplyPricing(t *testing.T) {
// Both custom rule and ApplyPricingToAccountStats are configured;
// custom rule should take precedence.
channel := &Channel{
ID: 1,
Status: StatusActive,
ApplyPricingToAccountStats: true,
AccountStatsPricingRules: []AccountStatsPricingRule{
{
GroupIDs: []int64{10},
Pricing: []ChannelModelPricing{
{
ID: 100,
Models: []string{"claude-sonnet-4"},
InputPrice: testPtrFloat64(0.05),
},
},
},
},
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
tokens := UsageTokens{InputTokens: 100}
result := resolveAccountStatsCost(
context.Background(),
cs, nil,
1, 10, "claude-sonnet-4",
tokens, 1, 99.0, // totalCost = 99.0 (would be used if ApplyPricing wins)
)
require.NotNil(t, result)
// Custom rule: 100*0.05 = 5.0 (NOT 99.0 from totalCost)
require.InDelta(t, 5.0, *result, 1e-12)
}
// ---------------------------------------------------------------------------
// helpers for resolveAccountStatsCost tests
// ---------------------------------------------------------------------------
// newTestChannelServiceForStats creates a ChannelService with a single channel
// mapped to the given groupID, suitable for resolveAccountStatsCost tests.
func newTestChannelServiceForStats(t *testing.T, channel *Channel, groupID int64, platform string) *ChannelService {
t.Helper()
cache := newEmptyChannelCache()
cache.channelByGroupID[groupID] = channel
cache.groupPlatform[groupID] = platform
cs := &ChannelService{}
cache.loadedAt = time.Now()
cs.cache.Store(cache)
return cs
}

View File

@@ -515,22 +515,10 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
mergeAccountExtra(account, updates)
}
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
if resetAt := codexRateLimitResetAtFromSnapshot(snapshot, time.Now()); resetAt != nil {
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
account.RateLimitResetAt = resetAt
}
}
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
if isOAuth && s.accountRepo != nil {
if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil {
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
account.RateLimitResetAt = resetAt
}
}
// 401 Unauthorized: 标记账号为永久错误
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))

View File

@@ -111,7 +111,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
require.Contains(t, recorder.Body.String(), "test_complete")
}
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
@@ -138,10 +138,7 @@ func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T)
require.Error(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
require.Equal(t, int64(88), repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.NotNil(t, account.RateLimitResetAt)
if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil {
require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second)
}
require.Zero(t, repo.rateLimitedID)
require.Nil(t, repo.rateLimitedAt)
require.Nil(t, account.RateLimitResetAt)
}

View File

@@ -499,7 +499,6 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
if account == nil {
return usage, nil
}
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now)
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
usage.FiveHour = progress
@@ -509,11 +508,8 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
}
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) {
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
mergeAccountExtra(account, updates)
if resetAt != nil {
account.RateLimitResetAt = resetAt
}
if usage.UpdatedAt == nil {
usage.UpdatedAt = &now
}
@@ -594,26 +590,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no
return true
}
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) {
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
if account == nil || !account.IsOAuth() {
return nil, nil, nil
return nil, nil
}
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return nil, nil, fmt.Errorf("no access token available")
return nil, fmt.Errorf("no access token available")
}
modelID := openaipkg.DefaultTestModel
payload := createOpenAITestPayload(modelID, true)
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err)
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
}
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
if err != nil {
return nil, nil, fmt.Errorf("create openai probe request: %w", err)
return nil, fmt.Errorf("create openai probe request: %w", err)
}
req.Host = "chatgpt.com"
req.Header.Set("Content-Type", "application/json")
@@ -642,67 +638,51 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
ResponseHeaderTimeout: 10 * time.Second,
})
if err != nil {
return nil, nil, fmt.Errorf("build openai probe client: %w", err)
return nil, fmt.Errorf("build openai probe client: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err)
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp)
updates, err := extractOpenAICodexProbeUpdates(resp)
if err != nil {
return nil, nil, err
return nil, err
}
if len(updates) > 0 || resetAt != nil {
s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt)
return updates, resetAt, nil
if len(updates) > 0 {
s.persistOpenAICodexProbeSnapshot(account.ID, updates)
return updates, nil
}
return nil, nil, nil
return nil, nil
}
func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) {
func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any) {
if s == nil || s.accountRepo == nil || accountID <= 0 {
return
}
if len(updates) == 0 && resetAt == nil {
if len(updates) == 0 {
return
}
go func() {
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
if len(updates) > 0 {
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}
if resetAt != nil {
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
}
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}()
}
func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) {
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
if resp == nil {
return nil, nil, nil
return nil, nil
}
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
baseTime := time.Now()
updates := buildCodexUsageExtraUpdates(snapshot, baseTime)
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime)
if len(updates) > 0 {
return updates, resetAt, nil
}
return nil, resetAt, nil
return buildCodexUsageExtraUpdates(snapshot, time.Now()), nil
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
}
return nil, nil, nil
}
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
updates, _, err := extractOpenAICodexProbeSnapshot(resp)
return updates, err
return nil, nil
}
func mergeAccountExtra(account *Account, updates map[string]any) {

View File

@@ -92,30 +92,7 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
}
}
func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) {
t.Parallel()
headers := make(http.Header)
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "604800")
headers.Set("x-codex-primary-window-minutes", "10080")
headers.Set("x-codex-secondary-used-percent", "100")
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
headers.Set("x-codex-secondary-window-minutes", "300")
updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
if err != nil {
t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err)
}
if len(updates) == 0 {
t.Fatal("expected codex probe updates from 429 headers")
}
if resetAt == nil {
t.Fatal("expected resetAt from exhausted codex headers")
}
}
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) {
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotOnlyUpdatesExtra(t *testing.T) {
t.Parallel()
repo := &accountUsageCodexProbeRepo{
@@ -123,12 +100,10 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
rateLimitCh: make(chan time.Time, 1),
}
svc := &AccountUsageService{accountRepo: repo}
resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
"codex_7d_used_percent": 100.0,
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
}, &resetAt)
"codex_7d_reset_at": time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second).Format(time.RFC3339),
})
select {
case updates := <-repo.updateExtraCh:
@@ -136,16 +111,49 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
}
case <-time.After(2 * time.Second):
t.Fatal("waiting for codex probe extra persistence timed out")
t.Fatal("等待 codex 探测快照写入 extra 超时")
}
select {
case got := <-repo.rateLimitCh:
if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) {
t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt)
}
case <-time.After(2 * time.Second):
t.Fatal("waiting for codex probe rate limit persistence timed out")
t.Fatalf("不应将探测快照写入运行时限流状态: %v", got)
case <-time.After(200 * time.Millisecond):
}
}
func TestAccountUsageService_GetOpenAIUsage_DoesNotPromoteCodexExtraToRateLimit(t *testing.T) {
t.Parallel()
resetAt := time.Now().Add(6 * 24 * time.Hour).UTC().Truncate(time.Second)
repo := &accountUsageCodexProbeRepo{
rateLimitCh: make(chan time.Time, 1),
}
svc := &AccountUsageService{accountRepo: repo}
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"codex_5h_used_percent": 1.0,
"codex_5h_reset_at": time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second).Format(time.RFC3339),
"codex_7d_used_percent": 100.0,
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
},
}
usage, err := svc.getOpenAIUsage(context.Background(), account)
if err != nil {
t.Fatalf("getOpenAIUsage() error = %v", err)
}
if usage.SevenDay == nil || usage.SevenDay.Utilization != 100.0 {
t.Fatalf("预期 7 天用量仍然可见,实际为 %#v", usage.SevenDay)
}
if account.RateLimitResetAt != nil {
t.Fatalf("不应让已耗尽的 codex extra 改写运行时限流状态: %v", account.RateLimitResetAt)
}
select {
case got := <-repo.rateLimitCh:
t.Fatalf("不应将已耗尽的 codex extra 持久化为运行时限流状态: %v", got)
case <-time.After(200 * time.Millisecond):
}
}

View File

@@ -0,0 +1,105 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGetWebSearchEmulationMode_Enabled(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
}
require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_Disabled(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "disabled"},
}
require.Equal(t, WebSearchModeDisabled, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_Default(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "default"},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_UnknownString(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "unknown"},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_OldBoolTrue(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: true},
}
// bool true → tolerant fallback → enabled (not default)
require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_OldBoolFalse(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: false},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_NilAccount(t *testing.T) {
var a *Account
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_NilExtra(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: nil,
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_MissingField(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_NonAnthropicPlatform(t *testing.T) {
a := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_NonAPIKeyType(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}

View File

@@ -1470,10 +1470,6 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
if err != nil {
return nil, 0, err
}
now := time.Now()
for i := range accounts {
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now)
}
return accounts, result.Total, nil
}

View File

@@ -65,14 +65,14 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected")
}
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
type apiKeyRepoStubForGroupUpdate struct {
@@ -131,9 +131,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str
func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
@@ -158,6 +155,9 @@ func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, in
func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) {
panic("unexpected")
}
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
type groupRepoStubForGroupUpdate struct {

View File

@@ -12,12 +12,12 @@ import (
type accountRepoStubForClearAccountError struct {
mockAccountRepoForGemini
account *Account
clearErrorCalls int
clearRateLimitCalls int
clearAntigravityCalls int
account *Account
clearErrorCalls int
clearRateLimitCalls int
clearAntigravityCalls int
clearModelRateLimitCalls int
clearTempUnschedCalls int
clearTempUnschedCalls int
}
func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) {
@@ -60,13 +60,13 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
resetAt := time.Now().Add(5 * time.Minute)
repo := &accountRepoStubForClearAccountError{
account: &Account{
ID: 31,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
ErrorMessage: "refresh failed",
RateLimitResetAt: &resetAt,
TempUnschedulableUntil: &until,
ID: 31,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusError,
ErrorMessage: "refresh failed",
RateLimitResetAt: &resetAt,
TempUnschedulableUntil: &until,
TempUnschedulableReason: "missing refresh token",
},
}

View File

@@ -34,6 +34,15 @@ type APIKeyAuthUserSnapshot struct {
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
// Balance notification fields (required for CheckBalanceAfterDeduction)
Email string `json:"email"`
Username string `json:"username"`
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
TotalRecharged float64 `json:"total_recharged"`
}
// APIKeyAuthGroupSnapshot 分组快照

View File

@@ -6,6 +6,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"log/slog"
"math/rand/v2"
"time"
@@ -13,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
const apiKeyAuthSnapshotVersion = 3
const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold
type apiKeyAuthCacheConfig struct {
l1Size int
@@ -99,7 +100,7 @@ func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context
s.authCacheL1.Del(cacheKey)
}); err != nil {
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
slog.Warn("failed to start auth cache invalidation subscriber", "error", err)
}
}
@@ -219,11 +220,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
RateLimit1d: apiKey.RateLimit1d,
RateLimit7d: apiKey.RateLimit7d,
User: APIKeyAuthUserSnapshot{
ID: apiKey.User.ID,
Status: apiKey.User.Status,
Role: apiKey.User.Role,
Balance: apiKey.User.Balance,
Concurrency: apiKey.User.Concurrency,
ID: apiKey.User.ID,
Status: apiKey.User.Status,
Role: apiKey.User.Role,
Balance: apiKey.User.Balance,
Concurrency: apiKey.User.Concurrency,
Email: apiKey.User.Email,
Username: apiKey.User.Username,
BalanceNotifyEnabled: apiKey.User.BalanceNotifyEnabled,
BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType,
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
TotalRecharged: apiKey.User.TotalRecharged,
},
}
if apiKey.Group != nil {
@@ -274,11 +282,18 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
RateLimit1d: snapshot.RateLimit1d,
RateLimit7d: snapshot.RateLimit7d,
User: &User{
ID: snapshot.User.ID,
Status: snapshot.User.Status,
Role: snapshot.User.Role,
Balance: snapshot.User.Balance,
Concurrency: snapshot.User.Concurrency,
ID: snapshot.User.ID,
Status: snapshot.User.Status,
Role: snapshot.User.Role,
Balance: snapshot.User.Balance,
Concurrency: snapshot.User.Concurrency,
Email: snapshot.User.Email,
Username: snapshot.User.Username,
BalanceNotifyEnabled: snapshot.User.BalanceNotifyEnabled,
BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType,
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
TotalRecharged: snapshot.User.TotalRecharged,
},
}
if snapshot.Group != nil {

View File

@@ -87,6 +87,18 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return nil
}
func (s *emailCacheStub) GetNotifyVerifyCode(ctx context.Context, email string) (*VerificationCodeData, error) {
return nil, nil
}
func (s *emailCacheStub) SetNotifyVerifyCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error {
return nil
}
func (s *emailCacheStub) DeleteNotifyVerifyCode(ctx context.Context, email string) error {
return nil
}
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
return nil, nil
}
@@ -107,6 +119,14 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai
return nil
}
func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
return 0, nil
}
func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
return 0, nil
}
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{

View File

@@ -0,0 +1,404 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// newBalanceNotifyServiceForTest constructs a BalanceNotifyService with an
// in-memory settings repo and a non-nil emailService so that the guard-clause
// nil-checks pass. The emailService is intentionally minimal — tests must
// avoid crossing scenarios that would actually dispatch emails.
func newBalanceNotifyServiceForTest() (*BalanceNotifyService, *mockSettingRepo) {
repo := newMockSettingRepo()
// EmailService is a concrete type; construct with the same repo so that
// any accidental fallback reads still succeed. Tests should not trigger a
// crossing that reaches SendEmail.
email := NewEmailService(repo, nil)
return NewBalanceNotifyService(email, repo, nil), repo
}
// ---------- guard clauses ----------
func TestCheckBalanceAfterDeduction_NilUser(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
// Should not panic.
s.CheckBalanceAfterDeduction(context.Background(), nil, 100, 50)
}
func TestCheckBalanceAfterDeduction_UserNotifyDisabled(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
repo.data[SettingKeyBalanceLowNotifyThreshold] = "10"
u := &User{ID: 1, BalanceNotifyEnabled: false}
// Even with a crossing, disabled flag short-circuits.
s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
}
func TestCheckBalanceAfterDeduction_GlobalDisabled(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "false"
u := &User{ID: 1, BalanceNotifyEnabled: true}
s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
}
func TestCheckBalanceAfterDeduction_ThresholdZero(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
repo.data[SettingKeyBalanceLowNotifyThreshold] = "0"
u := &User{ID: 1, BalanceNotifyEnabled: true}
s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
}
func TestCheckBalanceAfterDeduction_UserThresholdOverride(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
repo.data[SettingKeyBalanceLowNotifyThreshold] = "100" // global default
customThreshold := 5.0
u := &User{
ID: 1,
BalanceNotifyEnabled: true,
BalanceNotifyThreshold: &customThreshold,
}
// User's 5.0 threshold takes precedence over global 100. 20 -> 15 does not
// cross 5, so nothing fires (verified by absence of panic).
s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
}
func TestCheckBalanceAfterDeduction_NoCrossingNotFired(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
repo.data[SettingKeyBalanceLowNotifyThreshold] = "10"
u := &User{ID: 1, BalanceNotifyEnabled: true}
// 100 -> 95, both remain above threshold=10, no crossing.
s.CheckBalanceAfterDeduction(context.Background(), u, 100, 5)
// 5 -> 3, both already below threshold, no crossing (only fires on first
// cross from above-to-below).
s.CheckBalanceAfterDeduction(context.Background(), u, 5, 2)
}
// ---------- nil-service guards on CheckAccountQuotaAfterIncrement ----------
func TestCheckAccountQuotaAfterIncrement_NilAccount(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
// Should not panic.
s.CheckAccountQuotaAfterIncrement(context.Background(), nil, 10, nil)
}
func TestCheckAccountQuotaAfterIncrement_ZeroCost(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
s.CheckAccountQuotaAfterIncrement(context.Background(), a, 0, nil)
}
func TestCheckAccountQuotaAfterIncrement_NegativeCost(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
s.CheckAccountQuotaAfterIncrement(context.Background(), a, -5, nil)
}
func TestCheckAccountQuotaAfterIncrement_GlobalDisabled(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false"
a := &Account{
ID: 1,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"quota_notify_daily_enabled": true,
"quota_notify_daily_threshold": 100.0,
"quota_daily_limit": 1000.0,
"quota_daily_used": 950.0,
},
}
// Global disabled → no processing even if a dim would cross.
s.CheckAccountQuotaAfterIncrement(context.Background(), a, 100, nil)
}
// ---------- sanity: internal helpers still work ----------
func TestGetBalanceNotifyConfig_AllFields(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
repo.data[SettingKeyBalanceLowNotifyThreshold] = "12.5"
repo.data[SettingKeyBalanceLowNotifyRechargeURL] = "https://example.com/pay"
enabled, threshold, url := s.getBalanceNotifyConfig(context.Background())
require.True(t, enabled)
require.Equal(t, 12.5, threshold)
require.Equal(t, "https://example.com/pay", url)
}
func TestGetBalanceNotifyConfig_Disabled(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "false"
enabled, _, _ := s.getBalanceNotifyConfig(context.Background())
require.False(t, enabled)
}
func TestGetBalanceNotifyConfig_InvalidThreshold(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
repo.data[SettingKeyBalanceLowNotifyThreshold] = "not-a-number"
enabled, threshold, _ := s.getBalanceNotifyConfig(context.Background())
require.True(t, enabled)
require.Equal(t, 0.0, threshold)
}
func TestIsAccountQuotaNotifyEnabled(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
// Missing key → false
require.False(t, s.isAccountQuotaNotifyEnabled(context.Background()))
// Explicit "false"
repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false"
require.False(t, s.isAccountQuotaNotifyEnabled(context.Background()))
// Explicit "true"
repo.data[SettingKeyAccountQuotaNotifyEnabled] = "true"
require.True(t, s.isAccountQuotaNotifyEnabled(context.Background()))
}
func TestGetSiteName_FallsBackToDefault(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
name := s.getSiteName(context.Background())
require.Equal(t, defaultSiteName, name)
}
func TestGetSiteName_Configured(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeySiteName] = "My Site"
require.Equal(t, "My Site", s.getSiteName(context.Background()))
}
// ---------- crossedDownward ----------
func TestCrossedDownward_CrossesBelow(t *testing.T) {
// oldBalance > threshold, newBalance < threshold → true
require.True(t, crossedDownward(100, 5, 10))
}
func TestCrossedDownward_ExactlyAtThreshold(t *testing.T) {
// oldBalance > threshold, newBalance == threshold → false (not below)
require.False(t, crossedDownward(100, 10, 10))
}
func TestCrossedDownward_OldExactlyAtThreshold_NewBelow(t *testing.T) {
// oldBalance == threshold, newBalance < threshold → true
// (at-or-above → below counts as a crossing)
require.True(t, crossedDownward(10, 5, 10))
}
func TestCrossedDownward_AlreadyBelow(t *testing.T) {
// oldBalance < threshold → false (already below, no new crossing)
require.False(t, crossedDownward(5, 3, 10))
}
func TestCrossedDownward_BothAbove(t *testing.T) {
// oldBalance > threshold, newBalance > threshold → false (no crossing)
require.False(t, crossedDownward(100, 50, 10))
}
func TestCrossedDownward_ZeroThreshold(t *testing.T) {
// threshold == 0 → oldV >= 0 is always true, but newV < 0 only for negatives
// Typical case: positive balances should not fire when threshold is 0.
require.False(t, crossedDownward(10, 5, 0))
require.False(t, crossedDownward(0, 0, 0))
}
func TestCrossedDownward_ZeroThreshold_NegativeNew(t *testing.T) {
// Edge case: newBalance goes negative with threshold=0.
require.True(t, crossedDownward(5, -1, 0))
}
func TestCrossedDownward_NegativeValues(t *testing.T) {
// Both already negative, threshold is positive → no crossing (already below).
require.False(t, crossedDownward(-5, -10, 10))
}
func TestCrossedDownward_LargeDecrement(t *testing.T) {
// A single large deduction crosses the threshold.
require.True(t, crossedDownward(1000, 0.5, 100))
}
func TestCrossedDownward_SmallDecrement_NoCrossing(t *testing.T) {
// A tiny deduction stays above threshold.
require.False(t, crossedDownward(100, 99.99, 10))
}
// ---------- checkQuotaDimCrossings ----------
func TestCheckQuotaDimCrossings_NoDimensions(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// Empty dims → no crossing, no panic.
s.checkQuotaDimCrossings(account, nil, 10, []string{"admin@example.com"}, "TestSite")
s.checkQuotaDimCrossings(account, []quotaDim{}, 10, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_DisabledDimension(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
dims := []quotaDim{
{
name: quotaDimDaily,
enabled: false, // disabled
threshold: 100,
thresholdType: thresholdTypeFixed,
currentUsed: 950,
limit: 1000,
},
}
// Disabled dimension should be skipped even if crossing would occur.
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_ZeroThresholdSkipped(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
dims := []quotaDim{
{
name: quotaDimDaily,
enabled: true,
threshold: 0, // zero threshold
thresholdType: thresholdTypeFixed,
currentUsed: 950,
limit: 1000,
},
}
// Zero threshold → skipped.
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_NoCrossing_BothBelowThreshold(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
// currentUsed=300 (after), oldUsed=300-50=250 (before). Both < 600, no crossing.
dims := []quotaDim{
{
name: quotaDimDaily,
enabled: true,
threshold: 400,
thresholdType: thresholdTypeFixed,
currentUsed: 300,
limit: 1000,
},
}
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_NoCrossing_BothAboveThreshold(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
// currentUsed=800 (after), oldUsed=800-50=750 (before). Both >= 600, no crossing.
dims := []quotaDim{
{
name: quotaDimDaily,
enabled: true,
threshold: 400,
thresholdType: thresholdTypeFixed,
currentUsed: 800,
limit: 1000,
},
}
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_NegativeResolvedThreshold_Skipped(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// threshold=1200 remaining, limit=1000 → effectiveThreshold = 1000-1200 = -200
// Negative resolved threshold → skipped.
dims := []quotaDim{
{
name: quotaDimDaily,
enabled: true,
threshold: 1200,
thresholdType: thresholdTypeFixed,
currentUsed: 950,
limit: 1000,
},
}
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_PercentageThreshold_NoCrossing(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// threshold=30%, limit=1000 → effectiveThreshold = 1000 * (1 - 0.30) = 700
// currentUsed=500, oldUsed=500-50=450. Both < 700, no crossing.
dims := []quotaDim{
{
name: quotaDimWeekly,
enabled: true,
threshold: 30,
thresholdType: thresholdTypePercentage,
currentUsed: 500,
limit: 1000,
},
}
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_ZeroLimit_Skipped(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// limit=0 → resolvedThreshold returns 0 → skipped.
dims := []quotaDim{
{
name: quotaDimTotal,
enabled: true,
threshold: 100,
thresholdType: thresholdTypeFixed,
currentUsed: 50,
limit: 0,
},
}
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_MultipleDims_MixedResults(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// dim1: no crossing (both below effective threshold)
// dim2: disabled (skipped)
// dim3: zero threshold (skipped)
dims := []quotaDim{
{
name: quotaDimDaily,
enabled: true,
threshold: 400,
thresholdType: thresholdTypeFixed,
currentUsed: 300, // oldUsed=250, effectiveThreshold=600, both below
limit: 1000,
},
{
name: quotaDimWeekly,
enabled: false,
threshold: 100,
thresholdType: thresholdTypeFixed,
currentUsed: 900,
limit: 1000,
},
{
name: quotaDimTotal,
enabled: true,
threshold: 0,
thresholdType: thresholdTypeFixed,
currentUsed: 500,
limit: 1000,
},
}
// None should trigger. No panic expected.
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}

View File

@@ -0,0 +1,147 @@
//go:build unit
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// These tests guard against fmt.Sprintf arg-count mismatches in the email
// templates. A mismatch would produce "%!(EXTRA ...)" or "%!v(MISSING)" in
// the output, which these assertions will catch.
// ---------- buildBalanceLowEmailBody ----------
func TestBuildBalanceLowEmailBody_ContainsRequiredFields(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildBalanceLowEmailBody("Alice", 3.14, 10.0, "MySite", "")
// All substituted values should appear in the output.
require.Contains(t, body, "MySite")
require.Contains(t, body, "Alice")
require.Contains(t, body, "$3.14")
require.Contains(t, body, "$10.00")
// No fmt.Sprintf format error markers.
require.NotContains(t, body, "%!")
require.NotContains(t, body, "MISSING")
require.NotContains(t, body, "EXTRA")
}
func TestBuildBalanceLowEmailBody_WithRechargeURL(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildBalanceLowEmailBody("Bob", 5.0, 20.0, "Site", "https://example.com/pay")
// The recharge anchor element should appear with the URL.
require.Contains(t, body, `href="https://example.com/pay"`)
require.Contains(t, body, "立即充值")
require.NotContains(t, body, "%!")
}
func TestBuildBalanceLowEmailBody_RechargeURLEscaped(t *testing.T) {
s := &BalanceNotifyService{}
// Try a URL with characters that need HTML escaping.
body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", `https://example.com/?a=1&b=<script>`)
// `&` and `<` should be escaped in the href.
require.Contains(t, body, "&amp;")
require.Contains(t, body, "&lt;script&gt;")
require.NotContains(t, body, "<script>")
}
func TestBuildBalanceLowEmailBody_NoRechargeURLOmitsButton(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", "")
// The anchor element should not be rendered (style class may still appear).
require.NotContains(t, body, `<a href`)
require.NotContains(t, body, "立即充值")
}
// ---------- buildQuotaAlertEmailBody ----------
func TestBuildQuotaAlertEmailBody_AllFieldsPresent(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildQuotaAlertEmailBody(
42, // accountID
"acc-foo", // accountName
"anthropic", // platform
"日限额 / Daily", // dimLabel
750.50, // used
1000.0, // limit
249.50, // remaining
"$249.50", // thresholdDisplay
"MySite", // siteName
)
require.Contains(t, body, "MySite")
require.Contains(t, body, "#42")
require.Contains(t, body, "acc-foo")
require.Contains(t, body, "anthropic")
require.Contains(t, body, "Daily")
require.Contains(t, body, "$750.50")
require.Contains(t, body, "$1000.00")
require.Contains(t, body, "$249.50")
// No format error markers.
require.NotContains(t, body, "%!")
require.NotContains(t, body, "MISSING")
require.NotContains(t, body, "EXTRA")
}
func TestBuildQuotaAlertEmailBody_UnlimitedDisplay(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildQuotaAlertEmailBody(
1, "n", "p", "dim",
100.0, 0.0, // limit=0 triggers unlimited branch
0.0, "30%", "Site",
)
require.Contains(t, body, "无限制")
require.Contains(t, body, "Unlimited")
}
func TestBuildQuotaAlertEmailBody_PercentageThresholdDisplay(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildQuotaAlertEmailBody(
1, "n", "p", "dim",
700.0, 1000.0, 300.0,
"30%", // percentage-formatted threshold
"Site",
)
require.Contains(t, body, "30%")
require.NotContains(t, body, "%!")
}
func TestBuildQuotaAlertEmailBody_RemainingClampedAtZero(t *testing.T) {
// Even though caller is responsible for clamping, this test documents the
// display behavior with remaining=0.
s := &BalanceNotifyService{}
body := s.buildQuotaAlertEmailBody(
1, "n", "p", "dim",
1500.0, 1000.0, 0.0, // used > limit (over-quota)
"$100.00", "Site",
)
require.Contains(t, body, "$0.00")
}
// ---------- sanity checks on the CSS `%%` escape ----------
func TestBuildBalanceLowEmailBody_NoCSSFormatError(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", "")
// CSS `linear-gradient(135deg, #f59e0b 0%, #d97706 100%)` should appear with
// literal percent signs (from the %% escape in the template).
require.True(t,
strings.Contains(body, "0%") && strings.Contains(body, "100%"),
"CSS gradient percentages not rendered; got: %s", body)
}
func TestBuildQuotaAlertEmailBody_NoCSSFormatError(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildQuotaAlertEmailBody(1, "n", "p", "d", 0, 0, 0, "$0.00", "Site")
require.True(t,
strings.Contains(body, "0%") && strings.Contains(body, "100%"),
"CSS gradient percentages not rendered; got: %s", body)
}

View File

@@ -0,0 +1,479 @@
package service
import (
"context"
"fmt"
"html"
"log/slog"
"strconv"
"strings"
"time"
)
const (
emailSendTimeout = 30 * time.Second
// Threshold type values
thresholdTypeFixed = "fixed"
thresholdTypePercentage = "percentage"
// Quota dimension labels
quotaDimDaily = "daily"
quotaDimWeekly = "weekly"
quotaDimTotal = "total"
defaultSiteName = "Sub2API"
)
// quotaDimLabels maps dimension names to display labels.
var quotaDimLabels = map[string]string{
quotaDimDaily: "日限额 / Daily",
quotaDimWeekly: "周限额 / Weekly",
quotaDimTotal: "总限额 / Total",
}
// AccountQuotaReader provides read access to account quota data.
type AccountQuotaReader interface {
GetByID(ctx context.Context, id int64) (*Account, error)
}
// BalanceNotifyService handles balance and quota threshold notifications.
type BalanceNotifyService struct {
emailService *EmailService
settingRepo SettingRepository
accountRepo AccountQuotaReader
}
// NewBalanceNotifyService creates a new BalanceNotifyService.
func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountQuotaReader) *BalanceNotifyService {
return &BalanceNotifyService{
emailService: emailService,
settingRepo: settingRepo,
accountRepo: accountRepo,
}
}
// resolveBalanceThreshold returns the effective balance threshold.
// For percentage type, it computes threshold = totalRecharged * percentage / 100.
func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 {
if thresholdType == thresholdTypePercentage && totalRecharged > 0 {
return totalRecharged * threshold / 100
}
return threshold
}
// CheckBalanceAfterDeduction checks if balance crossed below threshold after deduction.
// Notification is sent only on first crossing: oldBalance >= threshold && newBalance < threshold.
func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, user *User, oldBalance, cost float64) {
if !s.canNotifyBalance(user) {
return
}
effectiveThreshold, rechargeURL, ok := s.resolveUserEffectiveThreshold(ctx, user)
if !ok {
return
}
newBalance := oldBalance - cost
if !crossedDownward(oldBalance, newBalance, effectiveThreshold) {
return
}
s.dispatchBalanceLowEmail(ctx, user, newBalance, effectiveThreshold, rechargeURL)
}
// canNotifyBalance checks nil guards and user-level toggle.
func (s *BalanceNotifyService) canNotifyBalance(user *User) bool {
if user == nil || s.emailService == nil || s.settingRepo == nil {
return false
}
return user.BalanceNotifyEnabled
}
// resolveUserEffectiveThreshold reads global + user config, returns the effective threshold.
// Returns ok=false when notifications should be skipped.
func (s *BalanceNotifyService) resolveUserEffectiveThreshold(ctx context.Context, user *User) (effectiveThreshold float64, rechargeURL string, ok bool) {
globalEnabled, globalThreshold, rechargeURL := s.getBalanceNotifyConfig(ctx)
if !globalEnabled {
return 0, "", false
}
threshold := globalThreshold
if user.BalanceNotifyThreshold != nil {
threshold = *user.BalanceNotifyThreshold
}
if threshold <= 0 {
return 0, "", false
}
effectiveThreshold = resolveBalanceThreshold(threshold, user.BalanceNotifyThresholdType, user.TotalRecharged)
if effectiveThreshold <= 0 {
return 0, "", false
}
return effectiveThreshold, rechargeURL, true
}
// crossedDownward returns true when oldV was at-or-above threshold but newV dropped below it.
func crossedDownward(oldV, newV, threshold float64) bool {
return oldV >= threshold && newV < threshold
}
// dispatchBalanceLowEmail collects recipients and sends the alert in a goroutine.
func (s *BalanceNotifyService) dispatchBalanceLowEmail(ctx context.Context, user *User, newBalance, threshold float64, rechargeURL string) {
siteName := s.getSiteName(ctx)
recipients := s.collectBalanceNotifyRecipients(user)
slog.Info("CheckBalanceAfterDeduction: sending notification",
"user_id", user.ID, "recipients", recipients, "new_balance", newBalance, "threshold", threshold)
go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in balance notification", "recover", r)
}
}()
s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL)
}()
}
// quotaDim describes one quota dimension for notification checking.
type quotaDim struct {
name string
enabled bool
threshold float64
thresholdType string // "fixed" (default) or "percentage"
currentUsed float64
limit float64
}
// resolvedThreshold converts the user-facing "remaining" threshold into a usage-based trigger point.
// The threshold represents how much quota REMAINS when the alert fires:
// - Fixed ($): threshold=400, limit=1000 → fires when usage reaches 600 (remaining drops to 400)
// - Percentage (%): threshold=30, limit=1000 → fires when usage reaches 700 (remaining drops to 30%)
func (d quotaDim) resolvedThreshold() float64 {
if d.limit <= 0 {
return 0
}
if d.thresholdType == thresholdTypePercentage {
return d.limit * (1 - d.threshold/100)
}
return d.limit - d.threshold
}
// buildQuotaDims returns the three quota dimensions for notification checking.
func buildQuotaDims(account *Account) []quotaDim {
return []quotaDim{
{quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaNotifyDailyThresholdType(), account.GetQuotaDailyUsed(), account.GetQuotaDailyLimit()},
{quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaNotifyWeeklyThresholdType(), account.GetQuotaWeeklyUsed(), account.GetQuotaWeeklyLimit()},
{quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaNotifyTotalThresholdType(), account.GetQuotaUsed(), account.GetQuotaLimit()},
}
}
// buildQuotaDimsFromState builds quota dimensions using DB transaction state instead of account snapshot.
// Notification settings (enabled, threshold, thresholdType) come from the account; usage values from quotaState.
func buildQuotaDimsFromState(account *Account, state *AccountQuotaState) []quotaDim {
return []quotaDim{
{quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaNotifyDailyThresholdType(), state.DailyUsed, state.DailyLimit},
{quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaNotifyWeeklyThresholdType(), state.WeeklyUsed, state.WeeklyLimit},
{quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaNotifyTotalThresholdType(), state.TotalUsed, state.TotalLimit},
}
}
// CheckAccountQuotaAfterIncrement checks if any quota dimension crossed above its notify threshold.
// When quotaState is non-nil (from DB transaction RETURNING), it is used directly for threshold
// checking, avoiding a separate DB read. Otherwise it falls back to fetching fresh account data.
func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Context, account *Account, cost float64, quotaState *AccountQuotaState) {
if account == nil || s.emailService == nil || s.settingRepo == nil || cost <= 0 {
return
}
if !s.isAccountQuotaNotifyEnabled(ctx) {
return
}
adminEmails := s.getAccountQuotaNotifyEmails(ctx)
if len(adminEmails) == 0 {
return
}
siteName := s.getSiteName(ctx)
var dims []quotaDim
if quotaState != nil {
dims = buildQuotaDimsFromState(account, quotaState)
} else {
freshAccount := s.fetchFreshAccount(ctx, account)
dims = buildQuotaDims(freshAccount)
account = freshAccount // use fresh data for alert metadata
}
s.checkQuotaDimCrossings(account, dims, cost, adminEmails, siteName)
}
// fetchFreshAccount loads the latest account from DB; falls back to the snapshot on error.
func (s *BalanceNotifyService) fetchFreshAccount(ctx context.Context, snapshot *Account) *Account {
if s.accountRepo == nil {
return snapshot
}
fresh, err := s.accountRepo.GetByID(ctx, snapshot.ID)
if err != nil {
slog.Warn("failed to fetch fresh account for quota notify, using snapshot",
"account_id", snapshot.ID, "error", err)
return snapshot
}
return fresh
}
// checkQuotaDimCrossings iterates pre-built quota dimensions and sends alerts for threshold crossings.
// Pre-increment value is reconstructed as currentUsed - cost to detect the crossing moment.
func (s *BalanceNotifyService) checkQuotaDimCrossings(account *Account, dims []quotaDim, cost float64, adminEmails []string, siteName string) {
for _, dim := range dims {
if !dim.enabled || dim.threshold <= 0 {
continue
}
effectiveThreshold := dim.resolvedThreshold()
if effectiveThreshold <= 0 {
continue
}
newUsed := dim.currentUsed
oldUsed := dim.currentUsed - cost
if oldUsed < effectiveThreshold && newUsed >= effectiveThreshold {
s.asyncSendQuotaAlert(adminEmails, account.ID, account.Name, account.Platform, dim, newUsed, effectiveThreshold, siteName)
}
}
}
// asyncSendQuotaAlert sends quota alert email in a goroutine with panic recovery.
func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, accountID int64, accountName, platform string, dim quotaDim, newUsed, effectiveThreshold float64, siteName string) {
go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in quota notification", "recover", r)
}
}()
s.sendQuotaAlertEmails(adminEmails, accountID, accountName, platform, dim, newUsed, siteName)
}()
}
// getBalanceNotifyConfig reads global balance notification settings.
func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, threshold float64, rechargeURL string) {
keys := []string{SettingKeyBalanceLowNotifyEnabled, SettingKeyBalanceLowNotifyThreshold, SettingKeyBalanceLowNotifyRechargeURL}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return false, 0, ""
}
enabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
if v := settings[SettingKeyBalanceLowNotifyThreshold]; v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil {
threshold = f
}
}
rechargeURL = settings[SettingKeyBalanceLowNotifyRechargeURL]
return
}
// isAccountQuotaNotifyEnabled checks the global account quota notification toggle.
func (s *BalanceNotifyService) isAccountQuotaNotifyEnabled(ctx context.Context) bool {
val, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEnabled)
if err != nil {
return false
}
return val == "true"
}
// getAccountQuotaNotifyEmails reads admin notification emails from settings,
// filtering out disabled and unverified entries.
func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) []string {
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEmails)
if err != nil || strings.TrimSpace(raw) == "" || raw == "[]" {
return nil
}
entries := ParseNotifyEmails(raw)
if len(entries) == 0 {
return nil
}
return filterVerifiedEmails(entries)
}
// getSiteName reads site name from settings with fallback.
func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
if err != nil || name == "" {
return defaultSiteName
}
return name
}
// filterVerifiedEmails returns deduplicated, non-disabled, verified emails.
func filterVerifiedEmails(entries []NotifyEmailEntry) []string {
var recipients []string
seen := make(map[string]bool)
for _, entry := range entries {
if entry.Disabled || !entry.Verified {
continue
}
email := strings.TrimSpace(entry.Email)
if email == "" {
continue
}
lower := strings.ToLower(email)
if seen[lower] {
continue
}
seen[lower] = true
recipients = append(recipients, email)
}
return recipients
}
// collectBalanceNotifyRecipients returns verified, non-disabled email recipients.
// Only emails with verified=true and disabled=false are included.
func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string {
return filterVerifiedEmails(user.BalanceNotifyExtraEmails)
}
// sendEmails sends an email to all recipients with shared timeout and error logging.
func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body string, logAttrs ...any) {
if len(recipients) == 0 {
slog.Warn("sendEmails: no recipients", "subject", subject)
return
}
for _, to := range recipients {
ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
if err := s.emailService.SendEmail(ctx, to, subject, body); err != nil {
attrs := append([]any{"to", to, "error", err}, logAttrs...)
slog.Error("failed to send notification", attrs...)
} else {
slog.Info("notification email sent successfully", "to", to, "subject", subject)
}
cancel()
}
}
// sendBalanceLowEmails sends balance low notification to all recipients.
func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) {
displayName := userName
if displayName == "" {
displayName = userEmail
}
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName))
body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName), rechargeURL)
s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
}
// sendQuotaAlertEmails sends quota alert notification to admin emails.
func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accountID int64, accountName, platform string, dim quotaDim, used float64, siteName string) {
dimLabel := quotaDimLabels[dim.name]
if dimLabel == "" {
dimLabel = dim.name
}
// Format the remaining-based threshold for display
thresholdDisplay := fmt.Sprintf("$%.2f", dim.threshold)
if dim.thresholdType == thresholdTypePercentage {
thresholdDisplay = fmt.Sprintf("%.0f%%", dim.threshold)
}
remaining := dim.limit - used
if remaining < 0 {
remaining = 0
}
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName))
body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, dim.limit, remaining, thresholdDisplay, html.EscapeString(siteName))
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dim.name)
}
// sanitizeEmailHeader removes CR/LF characters to prevent SMTP header injection.
func sanitizeEmailHeader(s string) string {
return strings.NewReplacer("\r", "", "\n", "").Replace(s)
}
// balanceLowEmailTemplate is the HTML template for balance low notifications.
// Format args: siteName, userName, userName, balance, threshold, threshold.
// The recharge button is appended dynamically when rechargeURL is set.
const balanceLowEmailTemplate = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #fff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #f59e0b 0%%, #d97706 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.balance { font-size: 36px; font-weight: bold; color: #dc2626; margin: 20px 0; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.recharge-btn { display: inline-block; margin-top: 24px; padding: 12px 32px; background: linear-gradient(135deg, #f59e0b 0%%, #d97706 100%%); color: #fff; text-decoration: none; border-radius: 6px; font-size: 16px; font-weight: bold; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header"><h1>%s</h1></div>
<div class="content">
<p style="font-size: 18px; color: #333;">%s您的余额不足</p>
<p style="color: #666;">Dear %s, your balance is running low</p>
<div class="balance">$%.2f</div>
<div class="info">
<p>您的账户余额已低于提醒阈值 <strong>$%.2f</strong>。</p>
<p>Your account balance has fallen below the alert threshold of <strong>$%.2f</strong>.</p>
<p>请及时充值以免服务中断。</p>
<p>Please top up to avoid service interruption.</p>
</div>
%s
</div>
<div class="footer"><p>此邮件由系统自动发送,请勿回复。</p></div>
</div>
</body>
</html>`
// quotaAlertEmailTemplate is the HTML template for account quota alert notifications.
// Format args: siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay.
const quotaAlertEmailTemplate = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #fff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #ef4444 0%%, #dc2626 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; }
.metric { display: flex; justify-content: space-between; padding: 12px 0; border-bottom: 1px solid #eee; }
.metric-label { color: #666; }
.metric-value { font-weight: bold; color: #333; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; text-align: center; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header"><h1>%s</h1></div>
<div class="content">
<p style="font-size: 18px; color: #333; text-align: center;">账号限额告警 / Account Quota Alert</p>
<div class="metric"><span class="metric-label">账号 ID / Account ID</span><span class="metric-value">#%d</span></div>
<div class="metric"><span class="metric-label">账号 / Account</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">平台 / Platform</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">维度 / Dimension</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">已使用 / Used</span><span class="metric-value">$%.2f</span></div>
<div class="metric"><span class="metric-label">限额 / Limit</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">剩余额度 / Remaining</span><span class="metric-value">$%.2f</span></div>
<div class="metric"><span class="metric-label">提醒阈值 / Alert Threshold</span><span class="metric-value">%s</span></div>
<div class="info">
<p>账号剩余额度已低于提醒阈值,请及时关注。</p>
<p>Account remaining quota has fallen below the alert threshold.</p>
</div>
</div>
<div class="footer"><p>此邮件由系统自动发送,请勿回复。</p></div>
</div>
</body>
</html>`
// buildBalanceLowEmailBody builds HTML email for balance low notification.
func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance, threshold float64, siteName, rechargeURL string) string {
rechargeBlock := ""
if rechargeURL != "" {
rechargeBlock = fmt.Sprintf(`<a href="%s" class="recharge-btn">立即充值 / Top Up Now</a>`, html.EscapeString(rechargeURL))
}
return fmt.Sprintf(balanceLowEmailTemplate, siteName, userName, userName, balance, threshold, threshold, rechargeBlock)
}
// buildQuotaAlertEmailBody builds HTML email for account quota alert.
func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountID int64, accountName, platform, dimLabel string, used, limit, remaining float64, thresholdDisplay, siteName string) string {
limitStr := fmt.Sprintf("$%.2f", limit)
if limit <= 0 {
limitStr = "无限制 / Unlimited"
}
return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay)
}

View File

@@ -0,0 +1,280 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// ---------- resolveBalanceThreshold ----------
func TestResolveBalanceThreshold_Fixed(t *testing.T) {
// Fixed type always returns the raw threshold regardless of totalRecharged.
require.Equal(t, 10.0, resolveBalanceThreshold(10, thresholdTypeFixed, 1000))
require.Equal(t, 10.0, resolveBalanceThreshold(10, thresholdTypeFixed, 0))
require.Equal(t, 0.0, resolveBalanceThreshold(0, thresholdTypeFixed, 1000))
}
func TestResolveBalanceThreshold_Percentage(t *testing.T) {
// 10% of 1000 = 100
require.Equal(t, 100.0, resolveBalanceThreshold(10, thresholdTypePercentage, 1000))
// 50% of 200 = 100
require.Equal(t, 100.0, resolveBalanceThreshold(50, thresholdTypePercentage, 200))
}
func TestResolveBalanceThreshold_PercentageZeroRecharged(t *testing.T) {
// When totalRecharged is 0, percentage falls through to raw threshold
// (treated as fixed). This is the defensive behavior.
require.Equal(t, 10.0, resolveBalanceThreshold(10, thresholdTypePercentage, 0))
}
func TestResolveBalanceThreshold_EmptyType(t *testing.T) {
// Empty type is treated as fixed (not percentage).
require.Equal(t, 10.0, resolveBalanceThreshold(10, "", 1000))
}
// ---------- quotaDim.resolvedThreshold ----------
func TestResolvedThreshold_FixedNormal(t *testing.T) {
// threshold=400 remaining, limit=1000 → usage trigger at 600
d := quotaDim{threshold: 400, thresholdType: thresholdTypeFixed, limit: 1000}
require.Equal(t, 600.0, d.resolvedThreshold())
}
func TestResolvedThreshold_FixedThresholdExceedsLimit(t *testing.T) {
// threshold=1200, limit=1000 → returns negative, callers must skip
d := quotaDim{threshold: 1200, thresholdType: thresholdTypeFixed, limit: 1000}
require.Equal(t, -200.0, d.resolvedThreshold())
}
func TestResolvedThreshold_FixedThresholdEqualsLimit(t *testing.T) {
// threshold=1000, limit=1000 → returns 0 (alert fires at 0 usage)
d := quotaDim{threshold: 1000, thresholdType: thresholdTypeFixed, limit: 1000}
require.Equal(t, 0.0, d.resolvedThreshold())
}
func TestResolvedThreshold_PercentageNormal(t *testing.T) {
// threshold=30%, limit=1000 → usage trigger at 700 (remaining drops to 30%)
d := quotaDim{threshold: 30, thresholdType: thresholdTypePercentage, limit: 1000}
require.InDelta(t, 700.0, d.resolvedThreshold(), 0.001)
}
func TestResolvedThreshold_PercentageZeroPercent(t *testing.T) {
// threshold=0%, limit=1000 → fires when remaining drops to 0 (usage=1000)
d := quotaDim{threshold: 0, thresholdType: thresholdTypePercentage, limit: 1000}
require.InDelta(t, 1000.0, d.resolvedThreshold(), 0.001)
}
func TestResolvedThreshold_PercentageHundredPercent(t *testing.T) {
// threshold=100%, limit=1000 → fires immediately (remaining drops to 100% i.e. nothing used yet)
d := quotaDim{threshold: 100, thresholdType: thresholdTypePercentage, limit: 1000}
require.InDelta(t, 0.0, d.resolvedThreshold(), 0.001)
}
func TestResolvedThreshold_PercentageOverHundred(t *testing.T) {
// threshold=150%, limit=1000 → returns negative (never triggers; callers skip)
d := quotaDim{threshold: 150, thresholdType: thresholdTypePercentage, limit: 1000}
require.Less(t, d.resolvedThreshold(), 0.0)
}
func TestResolvedThreshold_ZeroLimit(t *testing.T) {
// limit=0 → returns 0 to avoid division and false alerts on unlimited quotas
d := quotaDim{threshold: 100, thresholdType: thresholdTypeFixed, limit: 0}
require.Equal(t, 0.0, d.resolvedThreshold())
}
func TestResolvedThreshold_NegativeLimit(t *testing.T) {
// Negative limit treated as 0
d := quotaDim{threshold: 100, thresholdType: thresholdTypeFixed, limit: -10}
require.Equal(t, 0.0, d.resolvedThreshold())
}
// ---------- sanitizeEmailHeader ----------
func TestSanitizeEmailHeader_CRLF(t *testing.T) {
require.Equal(t, "Subject injected", sanitizeEmailHeader("Subject\r\n injected"))
}
func TestSanitizeEmailHeader_OnlyCR(t *testing.T) {
require.Equal(t, "foobar", sanitizeEmailHeader("foo\rbar"))
}
func TestSanitizeEmailHeader_OnlyLF(t *testing.T) {
require.Equal(t, "foobar", sanitizeEmailHeader("foo\nbar"))
}
func TestSanitizeEmailHeader_Clean(t *testing.T) {
require.Equal(t, "Sub2API", sanitizeEmailHeader("Sub2API"))
}
func TestSanitizeEmailHeader_Empty(t *testing.T) {
require.Equal(t, "", sanitizeEmailHeader(""))
}
func TestSanitizeEmailHeader_MultipleNewlines(t *testing.T) {
require.Equal(t, "abc", sanitizeEmailHeader("a\r\nb\r\nc"))
}
// ---------- buildQuotaDims ----------
func TestBuildQuotaDims_AllDimensionsReturned(t *testing.T) {
// Use an account with quota notify config across all 3 dimensions.
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"quota_notify_daily_enabled": true,
"quota_notify_daily_threshold": 100.0,
"quota_notify_daily_threshold_type": thresholdTypeFixed,
"quota_notify_weekly_enabled": true,
"quota_notify_weekly_threshold": 20.0,
"quota_notify_weekly_threshold_type": thresholdTypePercentage,
"quota_notify_total_enabled": false,
"quota_daily_limit": 500.0,
"quota_weekly_limit": 2000.0,
"quota_limit": 10000.0,
"quota_daily_used": 50.0,
"quota_weekly_used": 300.0,
"quota_used": 1000.0,
},
}
dims := buildQuotaDims(a)
require.Len(t, dims, 3)
// Daily
require.Equal(t, quotaDimDaily, dims[0].name)
require.True(t, dims[0].enabled)
require.Equal(t, 100.0, dims[0].threshold)
require.Equal(t, thresholdTypeFixed, dims[0].thresholdType)
require.Equal(t, 500.0, dims[0].limit)
require.Equal(t, 50.0, dims[0].currentUsed)
// Weekly
require.Equal(t, quotaDimWeekly, dims[1].name)
require.True(t, dims[1].enabled)
require.Equal(t, 20.0, dims[1].threshold)
require.Equal(t, thresholdTypePercentage, dims[1].thresholdType)
require.Equal(t, 2000.0, dims[1].limit)
// Total
require.Equal(t, quotaDimTotal, dims[2].name)
require.False(t, dims[2].enabled)
require.Equal(t, 10000.0, dims[2].limit)
require.Equal(t, 1000.0, dims[2].currentUsed)
}
func TestBuildQuotaDims_EmptyExtra(t *testing.T) {
// Missing fields default to zero/disabled.
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{},
}
dims := buildQuotaDims(a)
require.Len(t, dims, 3)
for _, d := range dims {
require.False(t, d.enabled)
require.Equal(t, 0.0, d.threshold)
require.Equal(t, 0.0, d.limit)
}
}
// ---------- buildQuotaDimsFromState ----------
func TestBuildQuotaDimsFromState_UsesStateValues(t *testing.T) {
// Usage values should come from the state, not the account.
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"quota_notify_daily_enabled": true,
"quota_notify_daily_threshold": 100.0,
"quota_daily_used": 999.0, // should be ignored
"quota_daily_limit": 999.0, // should be ignored
},
}
state := &AccountQuotaState{
DailyUsed: 77.0,
DailyLimit: 500.0,
WeeklyUsed: 88.0,
WeeklyLimit: 2000.0,
TotalUsed: 99.0,
TotalLimit: 10000.0,
}
dims := buildQuotaDimsFromState(a, state)
require.Len(t, dims, 3)
// Settings from account (enabled, threshold, thresholdType)
require.True(t, dims[0].enabled)
require.Equal(t, 100.0, dims[0].threshold)
// Usage from state
require.Equal(t, 77.0, dims[0].currentUsed)
require.Equal(t, 500.0, dims[0].limit)
require.Equal(t, 88.0, dims[1].currentUsed)
require.Equal(t, 2000.0, dims[1].limit)
require.Equal(t, 99.0, dims[2].currentUsed)
require.Equal(t, 10000.0, dims[2].limit)
}
// ---------- collectBalanceNotifyRecipients ----------
func TestCollectBalanceNotifyRecipients_Empty(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{BalanceNotifyExtraEmails: nil}
require.Empty(t, s.collectBalanceNotifyRecipients(u))
}
func TestCollectBalanceNotifyRecipients_FiltersDisabledAndUnverified(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{
BalanceNotifyExtraEmails: []NotifyEmailEntry{
{Email: "a@example.com", Verified: true, Disabled: false},
{Email: "b@example.com", Verified: true, Disabled: true}, // disabled
{Email: "c@example.com", Verified: false, Disabled: false}, // unverified
{Email: "d@example.com", Verified: true, Disabled: false},
},
}
got := s.collectBalanceNotifyRecipients(u)
require.Equal(t, []string{"a@example.com", "d@example.com"}, got)
}
func TestCollectBalanceNotifyRecipients_DeduplicatesCaseInsensitive(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{
BalanceNotifyExtraEmails: []NotifyEmailEntry{
{Email: "User@Example.com", Verified: true},
{Email: "user@example.com", Verified: true},
{Email: "USER@EXAMPLE.COM", Verified: true},
},
}
got := s.collectBalanceNotifyRecipients(u)
require.Len(t, got, 1)
// The original casing of the first entry is preserved.
require.Equal(t, "User@Example.com", got[0])
}
func TestCollectBalanceNotifyRecipients_SkipsEmpty(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{
BalanceNotifyExtraEmails: []NotifyEmailEntry{
{Email: " ", Verified: true},
{Email: "", Verified: true},
{Email: "valid@example.com", Verified: true},
},
}
got := s.collectBalanceNotifyRecipients(u)
require.Equal(t, []string{"valid@example.com"}, got)
}
func TestCollectBalanceNotifyRecipients_TrimsWhitespace(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{
BalanceNotifyExtraEmails: []NotifyEmailEntry{
{Email: " trimmed@example.com ", Verified: true},
},
}
got := s.collectBalanceNotifyRecipients(u)
require.Equal(t, []string{"trimmed@example.com"}, got)
}

View File

@@ -363,7 +363,6 @@ func TestCalculateImageCost(t *testing.T) {
require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
}
func TestIsModelSupported(t *testing.T) {
svc := newTestBillingService()
@@ -719,3 +718,123 @@ func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.
require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12)
require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12)
}
// ---------------------------------------------------------------------------
// GetModelPricingWithChannel
// ---------------------------------------------------------------------------
func TestGetModelPricingWithChannel_NilChannelPricing_ReturnsOriginal(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", nil)
require.NoError(t, err)
require.NotNil(t, pricing)
// Should be identical to GetModelPricing
original, err := svc.GetModelPricing("claude-sonnet-4")
require.NoError(t, err)
require.InDelta(t, original.InputPricePerToken, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, original.OutputPricePerToken, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, original.CacheCreationPricePerToken, pricing.CacheCreationPricePerToken, 1e-12)
require.InDelta(t, original.CacheReadPricePerToken, pricing.CacheReadPricePerToken, 1e-12)
}
func TestGetModelPricingWithChannel_OverrideInputPriceOnly(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
InputPrice: testPtrFloat64(99e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
// InputPrice overridden (both normal and priority)
require.InDelta(t, 99e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 99e-6, pricing.InputPricePerTokenPriority, 1e-12)
// OutputPrice unchanged (claude-sonnet-4 fallback = 15e-6)
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
}
func TestGetModelPricingWithChannel_OverrideOutputPriceOnly(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
OutputPrice: testPtrFloat64(88e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
// OutputPrice overridden
require.InDelta(t, 88e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 88e-6, pricing.OutputPricePerTokenPriority, 1e-12)
// InputPrice unchanged (claude-sonnet-4 fallback = 3e-6)
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12)
}
func TestGetModelPricingWithChannel_OverrideAllFields(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
InputPrice: testPtrFloat64(10e-6),
OutputPrice: testPtrFloat64(20e-6),
CacheWritePrice: testPtrFloat64(5e-6),
CacheReadPrice: testPtrFloat64(1e-6),
ImageOutputPrice: testPtrFloat64(50e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
require.InDelta(t, 10e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 10e-6, pricing.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 20e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 20e-6, pricing.OutputPricePerTokenPriority, 1e-12)
require.InDelta(t, 5e-6, pricing.CacheCreationPricePerToken, 1e-12)
require.InDelta(t, 5e-6, pricing.CacheCreation5mPrice, 1e-12)
require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12)
require.InDelta(t, 1e-6, pricing.CacheReadPricePerToken, 1e-12)
require.InDelta(t, 1e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
require.InDelta(t, 50e-6, pricing.ImageOutputPricePerToken, 1e-12)
}
func TestGetModelPricingWithChannel_CacheWritePriceAffects5mAnd1h(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
CacheWritePrice: testPtrFloat64(7e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
// CacheWritePrice should set all three: CacheCreationPricePerToken, 5m, and 1h
require.InDelta(t, 7e-6, pricing.CacheCreationPricePerToken, 1e-12)
require.InDelta(t, 7e-6, pricing.CacheCreation5mPrice, 1e-12)
require.InDelta(t, 7e-6, pricing.CacheCreation1hPrice, 1e-12)
}
func TestGetModelPricingWithChannel_CacheReadPriceAffectsPriority(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
CacheReadPrice: testPtrFloat64(2e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
// CacheReadPrice should set both normal and priority
require.InDelta(t, 2e-6, pricing.CacheReadPricePerToken, 1e-12)
require.InDelta(t, 2e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
}
func TestGetModelPricingWithChannel_UnknownModelReturnsError(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
InputPrice: testPtrFloat64(1e-6),
}
pricing, err := svc.GetModelPricingWithChannel("totally-unknown-model", chPricing)
require.Error(t, err)
require.Nil(t, pricing)
require.Contains(t, err.Error(), "pricing not found")
}

View File

@@ -0,0 +1,258 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// CalculateCostUnified
// ---------------------------------------------------------------------------
func TestCalculateCostUnified_NilResolver_FallsBackToOldPath(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
input := CostInput{
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.0,
Resolver: nil, // no resolver
}
cost, err := svc.CalculateCostUnified(input)
require.NoError(t, err)
// Should match the old-path result exactly
expected, err := svc.calculateCostInternal("claude-sonnet-4", tokens, 1.0, "", nil)
require.NoError(t, err)
require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-10)
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10)
// BillingMode is NOT set by old path through CalculateCostUnified (resolver == nil)
require.Empty(t, cost.BillingMode)
}
func TestCalculateCostUnified_TokenMode(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
input := CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.5,
Resolver: resolver,
}
cost, err := bs.CalculateCostUnified(input)
require.NoError(t, err)
require.NotNil(t, cost)
// Verify token billing: Input: 1000*3e-6=0.003, Output: 500*15e-6=0.0075
expectedTotal := 1000*3e-6 + 500*15e-6
require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10)
require.InDelta(t, expectedTotal*1.5, cost.ActualCost, 1e-10)
require.Equal(t, string(BillingModeToken), cost.BillingMode)
}
func TestCalculateCostUnified_PerRequestMode(t *testing.T) {
// Set up a ChannelService with a per-request pricing channel
cs := newTestChannelServiceWithCache(t, &channelCache{
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
{groupID: 1, model: "claude-sonnet-4"}: {
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.05),
},
},
channelByGroupID: map[int64]*Channel{
1: {ID: 1, Status: StatusActive},
},
groupPlatform: map[int64]string{1: ""},
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
mappingByGroupModel: map[channelModelKey]string{},
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
byID: map[int64]*Channel{},
})
bs := newTestBillingService()
resolver := NewModelPricingResolver(cs, bs)
groupID := int64(1)
input := CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
GroupID: &groupID,
Tokens: UsageTokens{InputTokens: 100, OutputTokens: 50},
RequestCount: 3,
RateMultiplier: 2.0,
Resolver: resolver,
}
cost, err := bs.CalculateCostUnified(input)
require.NoError(t, err)
require.NotNil(t, cost)
// 3 requests * $0.05 = $0.15
require.InDelta(t, 0.15, cost.TotalCost, 1e-10)
// ActualCost = 0.15 * 2.0 = 0.30
require.InDelta(t, 0.30, cost.ActualCost, 1e-10)
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
}
func TestCalculateCostUnified_ImageMode(t *testing.T) {
cs := newTestChannelServiceWithCache(t, &channelCache{
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
{groupID: 2, model: "gemini-image"}: {
BillingMode: BillingModeImage,
PerRequestPrice: testPtrFloat64(0.10),
},
},
channelByGroupID: map[int64]*Channel{
2: {ID: 2, Status: StatusActive},
},
groupPlatform: map[int64]string{2: ""},
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
mappingByGroupModel: map[channelModelKey]string{},
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
byID: map[int64]*Channel{},
})
bs := &BillingService{
cfg: &config.Config{},
fallbackPrices: map[string]*ModelPricing{},
}
resolver := NewModelPricingResolver(cs, bs)
groupID := int64(2)
input := CostInput{
Ctx: context.Background(),
Model: "gemini-image",
GroupID: &groupID,
Tokens: UsageTokens{},
RequestCount: 2,
RateMultiplier: 1.0,
Resolver: resolver,
}
cost, err := bs.CalculateCostUnified(input)
require.NoError(t, err)
require.NotNil(t, cost)
// 2 * $0.10 = $0.20
require.InDelta(t, 0.20, cost.TotalCost, 1e-10)
require.InDelta(t, 0.20, cost.ActualCost, 1e-10)
require.Equal(t, string(BillingModeImage), cost.BillingMode)
}
func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
costZero, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 0, // should default to 1.0
Resolver: resolver,
})
require.NoError(t, err)
costOne, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.0,
Resolver: resolver,
})
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
}
func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000}
costNeg, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: -5.0,
Resolver: resolver,
})
require.NoError(t, err)
costOne, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.0,
Resolver: resolver,
})
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
}
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: UsageTokens{InputTokens: 100},
RateMultiplier: 1.0,
Resolver: resolver,
})
require.NoError(t, err)
require.Equal(t, "token", cost.BillingMode)
}
func TestCalculateCostUnified_UsesPreResolvedPricing(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
// Pre-resolve with per_request mode to verify it's used instead of re-resolving
preResolved := &ResolvedPricing{
Mode: BillingModePerRequest,
DefaultPerRequestPrice: 0.07,
}
cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: UsageTokens{InputTokens: 100},
RequestCount: 2,
RateMultiplier: 1.0,
Resolver: resolver,
Resolved: preResolved,
})
require.NoError(t, err)
require.NotNil(t, cost)
// 2 * $0.07 = $0.14
require.InDelta(t, 0.14, cost.TotalCost, 1e-10)
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
}
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
// newTestChannelServiceWithCache creates a ChannelService with a pre-populated
// cache snapshot, bypassing the repository layer entirely.
func newTestChannelServiceWithCache(t *testing.T, cache *channelCache) *ChannelService {
t.Helper()
cs := &ChannelService{}
cache.loadedAt = time.Now()
cs.cache.Store(cache)
return cs
}

View File

@@ -37,8 +37,10 @@ type Channel struct {
Name string
Description string
Status string
BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
Features string // 渠道特性描述JSON 数组),用于支付页面展示
FeaturesConfig map[string]any // 渠道功能配置(如 web search emulation
CreatedAt time.Time
UpdatedAt time.Time
@@ -48,6 +50,25 @@ type Channel struct {
ModelPricing []ChannelModelPricing
// 渠道级模型映射按平台分组platform → {src→dst}
ModelMapping map[string]map[string]string
// 账号统计定价
ApplyPricingToAccountStats bool // 是否应用渠道模型定价到账号统计
AccountStatsPricingRules []AccountStatsPricingRule // 自定义账号统计定价规则(按 SortOrder 排序,先命中为准)
}
// AccountStatsPricingRule 账号统计定价规则
// 每条规则包含匹配条件(分组/账号)和独立的模型定价。
// 多条规则按 SortOrder 排序,先命中为准。
type AccountStatsPricingRule struct {
ID int64
ChannelID int64
Name string
GroupIDs []int64
AccountIDs []int64
SortOrder int
Pricing []ChannelModelPricing // 规则内的模型定价(复用现有定价结构)
CreatedAt time.Time
UpdatedAt time.Time
}
// ChannelModelPricing 渠道模型定价条目
@@ -176,9 +197,58 @@ func (c *Channel) Clone() *Channel {
cp.ModelMapping[platform] = inner
}
}
if c.FeaturesConfig != nil {
cp.FeaturesConfig = deepCopyFeaturesConfig(c.FeaturesConfig)
}
if c.AccountStatsPricingRules != nil {
cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules))
for i, rule := range c.AccountStatsPricingRules {
cp.AccountStatsPricingRules[i] = rule
if rule.GroupIDs != nil {
cp.AccountStatsPricingRules[i].GroupIDs = make([]int64, len(rule.GroupIDs))
copy(cp.AccountStatsPricingRules[i].GroupIDs, rule.GroupIDs)
}
if rule.AccountIDs != nil {
cp.AccountStatsPricingRules[i].AccountIDs = make([]int64, len(rule.AccountIDs))
copy(cp.AccountStatsPricingRules[i].AccountIDs, rule.AccountIDs)
}
if rule.Pricing != nil {
cp.AccountStatsPricingRules[i].Pricing = make([]ChannelModelPricing, len(rule.Pricing))
for j := range rule.Pricing {
cp.AccountStatsPricingRules[i].Pricing[j] = rule.Pricing[j].Clone()
}
}
}
}
return &cp
}
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
if c == nil || c.FeaturesConfig == nil {
return false
}
wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
if !ok {
return false
}
enabled, ok := wse[platform].(bool)
return ok && enabled
}
// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution.
func deepCopyFeaturesConfig(src map[string]any) map[string]any {
dst := make(map[string]any, len(src))
for k, v := range src {
if inner, ok := v.(map[string]any); ok {
dst[k] = deepCopyFeaturesConfig(inner)
} else {
dst[k] = v
}
}
return dst
}
// ValidateIntervals 校验区间列表的合法性。
// 规则MinTokens >= 0MaxTokens 若非 nil 则 > 0 且 > MinTokens
// 所有价格字段 >= 0区间按 MinTokens 排序后无重叠((min, max] 语义);

View File

@@ -81,9 +81,9 @@ type wildcardMappingEntry struct {
type channelCache struct {
// 热路径查找
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(按配置顺序,先匹配先使用
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(按配置顺序,先匹配先使用
channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform
@@ -315,6 +315,7 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *
expandMappingToCache(cache, ch, gid, platform)
}
}
return cache
}
@@ -415,6 +416,15 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
return ch.Clone(), nil
}
// GetGroupPlatform 获取分组的平台标识(从缓存)
func (s *ChannelService) GetGroupPlatform(ctx context.Context, groupID int64) string {
cache, err := s.loadCache(ctx)
if err != nil {
return ""
}
return cache.groupPlatform[groupID]
}
// channelLookup 热路径公共查找结果
type channelLookup struct {
cache *channelCache
@@ -556,15 +566,21 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
// Create 和 Update 共用此函数,避免重复。
func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
if err := validatePricingEntries(pricing); err != nil {
return err
}
return validateNoConflictingMappings(mapping)
}
// validatePricingEntries 校验定价条目(冲突检测 + 区间校验 + 计费模式校验),
// 同时用于主渠道定价和 account_stats_pricing_rules 的内部定价。
func validatePricingEntries(pricing []ChannelModelPricing) error {
if err := validateNoConflictingModels(pricing); err != nil {
return err
}
if err := validatePricingIntervals(pricing); err != nil {
return err
}
if err := validateNoConflictingMappings(mapping); err != nil {
return err
}
return validatePricingBillingMode(pricing)
}
@@ -655,14 +671,18 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
}
channel := &Channel{
Name: input.Name,
Description: input.Description,
Status: StatusActive,
BillingModelSource: input.BillingModelSource,
RestrictModels: input.RestrictModels,
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
Name: input.Name,
Description: input.Description,
Status: StatusActive,
BillingModelSource: input.BillingModelSource,
RestrictModels: input.RestrictModels,
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
Features: input.Features,
FeaturesConfig: input.FeaturesConfig,
ApplyPricingToAccountStats: input.ApplyPricingToAccountStats,
AccountStatsPricingRules: input.AccountStatsPricingRules,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceChannelMapped
@@ -671,6 +691,11 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
for i, rule := range channel.AccountStatsPricingRules {
if err := validatePricingEntries(rule.Pricing); err != nil {
return nil, fmt.Errorf("account stats pricing rule #%d: %w", i+1, err)
}
}
if err := s.repo.Create(ctx, channel); err != nil {
return nil, fmt.Errorf("create channel: %w", err)
@@ -699,6 +724,11 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
for i, rule := range channel.AccountStatsPricingRules {
if err := validatePricingEntries(rule.Pricing); err != nil {
return nil, fmt.Errorf("account stats pricing rule #%d: %w", i+1, err)
}
}
oldGroupIDs := s.getOldGroupIDs(ctx, id)
@@ -733,6 +763,9 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
if input.RestrictModels != nil {
channel.RestrictModels = *input.RestrictModels
}
if input.Features != nil {
channel.Features = *input.Features
}
if input.GroupIDs != nil {
if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
return err
@@ -748,6 +781,15 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
if input.BillingModelSource != "" {
channel.BillingModelSource = input.BillingModelSource
}
if input.FeaturesConfig != nil {
channel.FeaturesConfig = input.FeaturesConfig
}
if input.ApplyPricingToAccountStats != nil {
channel.ApplyPricingToAccountStats = *input.ApplyPricingToAccountStats
}
if input.AccountStatsPricingRules != nil {
channel.AccountStatsPricingRules = *input.AccountStatsPricingRules
}
return nil
}
@@ -913,23 +955,31 @@ func detectConflicts(entries []modelEntry, platform, errCode, label string) erro
// CreateChannelInput 创建渠道输入
type CreateChannelInput struct {
Name string
Description string
GroupIDs []int64
ModelPricing []ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels bool
Name string
Description string
GroupIDs []int64
ModelPricing []ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels bool
Features string
FeaturesConfig map[string]any
ApplyPricingToAccountStats bool
AccountStatsPricingRules []AccountStatsPricingRule
}
// UpdateChannelInput 更新渠道输入
type UpdateChannelInput struct {
Name string
Description *string
Status string
GroupIDs *[]int64
ModelPricing *[]ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels *bool
Name string
Description *string
Status string
GroupIDs *[]int64
ModelPricing *[]ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels *bool
Features *string
FeaturesConfig map[string]any
ApplyPricingToAccountStats *bool
AccountStatsPricingRules *[]AccountStatsPricingRule
}

View File

@@ -0,0 +1,62 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestChannel_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
},
}
require.True(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_DifferentPlatform(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
},
}
require.False(t, c.IsWebSearchEmulationEnabled("openai"))
}
func TestChannel_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"anthropic": false},
},
}
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_NilFeaturesConfig(t *testing.T) {
c := &Channel{FeaturesConfig: nil}
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_NilChannel(t *testing.T) {
var c *Channel
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_WrongStructure(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: true, // not a map
},
}
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_PlatformValueNotBool(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"anthropic": "yes"},
},
}
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}

View File

@@ -343,8 +343,9 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
// Uses a detached context with timeout to prevent HTTP request cancellation from
// causing the entire batch to fail (which would show all concurrency as 0).
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return map[int64]int{}, nil
@@ -356,5 +357,11 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc
}
return result, nil
}
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
// Use a detached context so that a cancelled HTTP request doesn't cause
// the Redis pipeline to fail and return all-zero concurrency counts.
redisCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
return s.cache.GetAccountConcurrencyBatch(redisCtx, accountIDs)
}

View File

@@ -249,6 +249,18 @@ const (
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false
SettingKeyEnableCCHSigning = "enable_cch_signing"
// Balance Low Notification
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值USD
SettingKeyBalanceLowNotifyRechargeURL = "balance_low_notify_recharge_url" // 充值页面 URL
// Account Quota Notification
SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关
SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表JSON 数组)
// Web Search Emulation
SettingKeyWebSearchEmulationConfig = "web_search_emulation_config" // JSON 配置
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@@ -7,8 +7,9 @@ import (
"crypto/tls"
"encoding/hex"
"fmt"
"log"
"log/slog"
"math/big"
"net"
"net/smtp"
"net/url"
"strconv"
@@ -34,6 +35,11 @@ type EmailCache interface {
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteVerificationCode(ctx context.Context, email string) error
// Notify email verification code methods
GetNotifyVerifyCode(ctx context.Context, email string) (*VerificationCodeData, error)
SetNotifyVerifyCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteNotifyVerifyCode(ctx context.Context, email string) error
// Password reset token methods
GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error)
SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error
@@ -43,6 +49,10 @@ type EmailCache interface {
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
// Notify code rate limiting per user
IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error)
GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error)
}
// VerificationCodeData represents verification code data
@@ -50,6 +60,7 @@ type VerificationCodeData struct {
Code string
Attempts int
CreatedAt time.Time
ExpiresAt time.Time // absolute expiry; used to preserve remaining TTL when updating attempts
}
// PasswordResetTokenData represents password reset token data
@@ -146,11 +157,18 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
return s.SendEmailWithConfig(config, to, subject, body)
}
const smtpDialTimeout = 10 * time.Second
const smtpIOTimeout = 20 * time.Second
// SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
from := config.From
// Sanitize all SMTP header fields to prevent header injection (CR/LF removal).
to = sanitizeEmailHeader(to)
subject = sanitizeEmailHeader(subject)
from := sanitizeEmailHeader(config.From)
if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
from = fmt.Sprintf("%s <%s>", sanitizeEmailHeader(config.FromName), sanitizeEmailHeader(config.From))
}
msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: %s\r\nMIME-Version: 1.0\r\nContent-Type: text/html; charset=UTF-8\r\n\r\n%s",
@@ -163,7 +181,54 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body
return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host)
}
return smtp.SendMail(addr, auth, config.From, []string{to}, []byte(msg))
return s.sendMailPlain(addr, auth, config.From, to, []byte(msg), config.Host)
}
// sendMailPlain sends mail without TLS using a dialer with timeout.
func (s *EmailService) sendMailPlain(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error {
dialer := &net.Dialer{Timeout: smtpDialTimeout}
conn, err := dialer.Dial("tcp", addr)
if err != nil {
return fmt.Errorf("smtp dial: %w", err)
}
_ = conn.SetDeadline(time.Now().Add(smtpIOTimeout))
defer func() { _ = conn.Close() }()
client, err := smtp.NewClient(conn, host)
if err != nil {
return fmt.Errorf("new smtp client: %w", err)
}
defer func() { _ = client.Close() }()
// Opportunistic STARTTLS: upgrade to encrypted connection if the server supports it.
// This mirrors the behavior of smtp.SendMail which we replaced for timeout support.
if ok, _ := client.Extension("STARTTLS"); ok {
if err = client.StartTLS(&tls.Config{ServerName: host, MinVersion: tls.VersionTLS12}); err != nil {
return fmt.Errorf("starttls: %w", err)
}
}
if err = client.Auth(auth); err != nil {
return fmt.Errorf("smtp auth: %w", err)
}
if err = client.Mail(from); err != nil {
return fmt.Errorf("smtp mail: %w", err)
}
if err = client.Rcpt(to); err != nil {
return fmt.Errorf("smtp rcpt: %w", err)
}
w, err := client.Data()
if err != nil {
return fmt.Errorf("smtp data: %w", err)
}
if _, err = w.Write(msg); err != nil {
return fmt.Errorf("write msg: %w", err)
}
if err = w.Close(); err != nil {
return fmt.Errorf("close writer: %w", err)
}
_ = client.Quit()
return nil
}
// sendMailTLS 使用TLS发送邮件
@@ -174,10 +239,12 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
MinVersion: tls.VersionTLS12,
}
conn, err := tls.Dial("tcp", addr, tlsConfig)
dialer := &net.Dialer{Timeout: smtpDialTimeout}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
if err != nil {
return fmt.Errorf("tls dial: %w", err)
}
_ = conn.SetDeadline(time.Now().Add(smtpIOTimeout))
defer func() { _ = conn.Close() }()
client, err := smtp.NewClient(conn, host)
@@ -254,6 +321,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
Code: code,
Attempts: 0,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(verifyCodeTTL),
}
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err)
@@ -286,8 +354,12 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
log.Printf("[Email] Failed to update verification attempt count: %v", err)
remaining := time.Until(data.ExpiresAt)
if remaining <= 0 {
return ErrInvalidVerifyCode
}
if err := s.cache.SetVerificationCode(ctx, email, data, remaining); err != nil {
slog.Error("failed to update verification attempt count", "email", email, "error", err)
}
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
@@ -297,7 +369,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证成功,删除验证码
if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
log.Printf("[Email] Failed to delete verification code after success: %v", err)
slog.Error("failed to delete verification code after success", "email", email, "error", err)
}
return nil
}
@@ -447,7 +519,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
// Check email cooldown to prevent email bombing
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
log.Printf("[Email] Password reset email skipped (cooldown): %s", email)
slog.Info("password reset email skipped due to cooldown", "email", email)
return nil // Silent success to prevent revealing cooldown to attackers
}
@@ -458,7 +530,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e
// Set cooldown marker (Redis TTL handles expiration)
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err)
slog.Error("failed to set password reset cooldown", "email", email, "error", err)
}
return nil
@@ -488,7 +560,7 @@ func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, tok
// Delete after verification (one-time use)
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err)
slog.Error("failed to delete password reset token after consumption", "email", email, "error", err)
}
return nil
}

Some files were not shown because too many files have changed in this diff Show More