diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml
index 6f76ef4f..d7e15377 100644
--- a/.github/workflows/backend-ci.yml
+++ b/.github/workflows/backend-ci.yml
@@ -17,6 +17,7 @@ jobs:
go-version-file: backend/go.mod
check-latest: false
cache: true
+ cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.26.2'
@@ -36,6 +37,7 @@ jobs:
go-version-file: backend/go.mod
check-latest: false
cache: true
+ cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.26.2'
diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go
index 47f8f518..64709b5b 100644
--- a/backend/cmd/server/wire.go
+++ b/backend/cmd/server/wire.go
@@ -36,19 +36,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
// Business layer ProviderSets
repository.ProviderSet,
service.ProviderSet,
+ payment.ProviderSet,
middleware.ProviderSet,
handler.ProviderSet,
// Server layer ProviderSet
server.ProviderSet,
- // Payment providers
- payment.ProvideRegistry,
- payment.ProvideEncryptionKey,
- payment.ProvideDefaultLoadBalancer,
- service.ProvidePaymentConfigService,
- service.ProvidePaymentOrderExpiryService,
-
// Privacy client factory for OpenAI training opt-out
providePrivacyClientFactory,
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index c288a289..1d39fa1e 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -50,8 +50,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
settingRepository := repository.NewSettingRepository(client)
groupRepository := repository.NewGroupRepository(client, db)
- channelRepository := repository.NewChannelRepository(db)
- settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
+ proxyRepository := repository.NewProxyRepository(client, db)
+ settingService := service.ProvideSettingService(settingRepository, groupRepository, proxyRepository, configConfig)
emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier()
@@ -65,23 +65,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
- apiKeyService.SetRateLimitCacheInvalidator(billingCache)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
- userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
+ userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
- registry := payment.ProvideRegistry()
- encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
- if err != nil {
- return nil, err
- }
- defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
- paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
- paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
- paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
if err != nil {
return nil, err
@@ -89,10 +79,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
- userHandler := handler.NewUserHandler(userService)
+ userHandler := handler.NewUserHandler(userService, emailService, emailCache)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
- usageBillingRepository := repository.NewUsageBillingRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemHandler := handler.NewRedeemHandler(redeemService)
@@ -112,7 +101,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
schedulerCache := repository.ProvideSchedulerCache(redisClient, configConfig)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
- proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
privacyClientFactory := providePrivacyClientFactory()
@@ -120,11 +108,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
+ sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
+ rpmCache := repository.NewRPMCache(redisClient)
+ groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
+ groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
- openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient()
@@ -134,7 +125,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
- oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig)
@@ -142,23 +132,20 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
identityCache := repository.NewIdentityCache(redisClient)
- geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
- gatewayCache := repository.NewGatewayCache(redisClient)
- schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
- schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
- antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
- internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
+ oAuthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
+ geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
+ gatewayCache := repository.NewGatewayCache(redisClient)
+ schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
+ schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
+ antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
+ internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
- sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
- rpmCache := repository.NewRPMCache(redisClient)
- groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
- groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService()
@@ -175,6 +162,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
+ usageBillingRepository := repository.NewUsageBillingRepository(client, db)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
@@ -183,17 +171,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
- claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
+ claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
+ channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
- gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
- openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
- openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService)
+ balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
+ gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
+ openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
+ openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
- settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
@@ -221,8 +210,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
channelHandler := admin.NewChannelHandler(channelService, billingService)
- adminPaymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, adminPaymentHandler)
+ registry := payment.ProvideRegistry()
+ encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
+ if err != nil {
+ return nil, err
+ }
+ defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
+ paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
+ paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
+ settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
+ paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
+ paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -245,7 +244,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
- tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
+ tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index e947b2e8..68bdbf55 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -616,6 +616,7 @@ var (
{Name: "sort_order", Type: field.TypeInt, Default: 0},
{Name: "limits", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "refund_enabled", Type: field.TypeBool, Default: false},
+ {Name: "allow_user_refund", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
}
@@ -1078,6 +1079,11 @@ var (
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
+ {Name: "balance_notify_enabled", Type: field.TypeBool, Default: true},
+ {Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"},
+ {Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
+ {Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 6b2fa838..524ccb92 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -15642,25 +15642,26 @@ func (m *PaymentOrderMutation) ResetEdge(name string) error {
// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph.
type PaymentProviderInstanceMutation struct {
config
- op Op
- typ string
- id *int64
- provider_key *string
- name *string
- _config *string
- supported_types *string
- enabled *bool
- payment_mode *string
- sort_order *int
- addsort_order *int
- limits *string
- refund_enabled *bool
- created_at *time.Time
- updated_at *time.Time
- clearedFields map[string]struct{}
- done bool
- oldValue func(context.Context) (*PaymentProviderInstance, error)
- predicates []predicate.PaymentProviderInstance
+ op Op
+ typ string
+ id *int64
+ provider_key *string
+ name *string
+ _config *string
+ supported_types *string
+ enabled *bool
+ payment_mode *string
+ sort_order *int
+ addsort_order *int
+ limits *string
+ refund_enabled *bool
+ allow_user_refund *bool
+ created_at *time.Time
+ updated_at *time.Time
+ clearedFields map[string]struct{}
+ done bool
+ oldValue func(context.Context) (*PaymentProviderInstance, error)
+ predicates []predicate.PaymentProviderInstance
}
var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil)
@@ -16105,6 +16106,42 @@ func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() {
m.refund_enabled = nil
}
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) {
+ m.allow_user_refund = &b
+}
+
+// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation.
+func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) {
+ v := m.allow_user_refund
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAllowUserRefund requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err)
+ }
+ return oldValue.AllowUserRefund, nil
+}
+
+// ResetAllowUserRefund resets all changes to the "allow_user_refund" field.
+func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() {
+ m.allow_user_refund = nil
+}
+
// SetCreatedAt sets the "created_at" field.
func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) {
m.created_at = &t
@@ -16211,7 +16248,7 @@ func (m *PaymentProviderInstanceMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *PaymentProviderInstanceMutation) Fields() []string {
- fields := make([]string, 0, 11)
+ fields := make([]string, 0, 12)
if m.provider_key != nil {
fields = append(fields, paymentproviderinstance.FieldProviderKey)
}
@@ -16239,6 +16276,9 @@ func (m *PaymentProviderInstanceMutation) Fields() []string {
if m.refund_enabled != nil {
fields = append(fields, paymentproviderinstance.FieldRefundEnabled)
}
+ if m.allow_user_refund != nil {
+ fields = append(fields, paymentproviderinstance.FieldAllowUserRefund)
+ }
if m.created_at != nil {
fields = append(fields, paymentproviderinstance.FieldCreatedAt)
}
@@ -16271,6 +16311,8 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) {
return m.Limits()
case paymentproviderinstance.FieldRefundEnabled:
return m.RefundEnabled()
+ case paymentproviderinstance.FieldAllowUserRefund:
+ return m.AllowUserRefund()
case paymentproviderinstance.FieldCreatedAt:
return m.CreatedAt()
case paymentproviderinstance.FieldUpdatedAt:
@@ -16302,6 +16344,8 @@ func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name str
return m.OldLimits(ctx)
case paymentproviderinstance.FieldRefundEnabled:
return m.OldRefundEnabled(ctx)
+ case paymentproviderinstance.FieldAllowUserRefund:
+ return m.OldAllowUserRefund(ctx)
case paymentproviderinstance.FieldCreatedAt:
return m.OldCreatedAt(ctx)
case paymentproviderinstance.FieldUpdatedAt:
@@ -16378,6 +16422,13 @@ func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value)
}
m.SetRefundEnabled(v)
return nil
+ case paymentproviderinstance.FieldAllowUserRefund:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAllowUserRefund(v)
+ return nil
case paymentproviderinstance.FieldCreatedAt:
v, ok := value.(time.Time)
if !ok {
@@ -16483,6 +16534,9 @@ func (m *PaymentProviderInstanceMutation) ResetField(name string) error {
case paymentproviderinstance.FieldRefundEnabled:
m.ResetRefundEnabled()
return nil
+ case paymentproviderinstance.FieldAllowUserRefund:
+ m.ResetAllowUserRefund()
+ return nil
case paymentproviderinstance.FieldCreatedAt:
m.ResetCreatedAt()
return nil
@@ -28210,6 +28264,13 @@ type UserMutation struct {
totp_secret_encrypted *string
totp_enabled *bool
totp_enabled_at *time.Time
+ balance_notify_enabled *bool
+ balance_notify_threshold_type *string
+ balance_notify_threshold *float64
+ addbalance_notify_threshold *float64
+ balance_notify_extra_emails *string
+ total_recharged *float64
+ addtotal_recharged *float64
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -28927,6 +28988,240 @@ func (m *UserMutation) ResetTotpEnabledAt() {
delete(m.clearedFields, user.FieldTotpEnabledAt)
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (m *UserMutation) SetBalanceNotifyEnabled(b bool) {
+ m.balance_notify_enabled = &b
+}
+
+// BalanceNotifyEnabled returns the value of the "balance_notify_enabled" field in the mutation.
+func (m *UserMutation) BalanceNotifyEnabled() (r bool, exists bool) {
+ v := m.balance_notify_enabled
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBalanceNotifyEnabled returns the old "balance_notify_enabled" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldBalanceNotifyEnabled(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBalanceNotifyEnabled is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBalanceNotifyEnabled requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBalanceNotifyEnabled: %w", err)
+ }
+ return oldValue.BalanceNotifyEnabled, nil
+}
+
+// ResetBalanceNotifyEnabled resets all changes to the "balance_notify_enabled" field.
+func (m *UserMutation) ResetBalanceNotifyEnabled() {
+ m.balance_notify_enabled = nil
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (m *UserMutation) SetBalanceNotifyThresholdType(s string) {
+ m.balance_notify_threshold_type = &s
+}
+
+// BalanceNotifyThresholdType returns the value of the "balance_notify_threshold_type" field in the mutation.
+func (m *UserMutation) BalanceNotifyThresholdType() (r string, exists bool) {
+ v := m.balance_notify_threshold_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBalanceNotifyThresholdType returns the old "balance_notify_threshold_type" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldBalanceNotifyThresholdType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBalanceNotifyThresholdType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBalanceNotifyThresholdType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBalanceNotifyThresholdType: %w", err)
+ }
+ return oldValue.BalanceNotifyThresholdType, nil
+}
+
+// ResetBalanceNotifyThresholdType resets all changes to the "balance_notify_threshold_type" field.
+func (m *UserMutation) ResetBalanceNotifyThresholdType() {
+ m.balance_notify_threshold_type = nil
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (m *UserMutation) SetBalanceNotifyThreshold(f float64) {
+ m.balance_notify_threshold = &f
+ m.addbalance_notify_threshold = nil
+}
+
+// BalanceNotifyThreshold returns the value of the "balance_notify_threshold" field in the mutation.
+func (m *UserMutation) BalanceNotifyThreshold() (r float64, exists bool) {
+ v := m.balance_notify_threshold
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBalanceNotifyThreshold returns the old "balance_notify_threshold" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldBalanceNotifyThreshold(ctx context.Context) (v *float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBalanceNotifyThreshold is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBalanceNotifyThreshold requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBalanceNotifyThreshold: %w", err)
+ }
+ return oldValue.BalanceNotifyThreshold, nil
+}
+
+// AddBalanceNotifyThreshold adds f to the "balance_notify_threshold" field.
+func (m *UserMutation) AddBalanceNotifyThreshold(f float64) {
+ if m.addbalance_notify_threshold != nil {
+ *m.addbalance_notify_threshold += f
+ } else {
+ m.addbalance_notify_threshold = &f
+ }
+}
+
+// AddedBalanceNotifyThreshold returns the value that was added to the "balance_notify_threshold" field in this mutation.
+func (m *UserMutation) AddedBalanceNotifyThreshold() (r float64, exists bool) {
+ v := m.addbalance_notify_threshold
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (m *UserMutation) ClearBalanceNotifyThreshold() {
+ m.balance_notify_threshold = nil
+ m.addbalance_notify_threshold = nil
+ m.clearedFields[user.FieldBalanceNotifyThreshold] = struct{}{}
+}
+
+// BalanceNotifyThresholdCleared returns if the "balance_notify_threshold" field was cleared in this mutation.
+func (m *UserMutation) BalanceNotifyThresholdCleared() bool {
+ _, ok := m.clearedFields[user.FieldBalanceNotifyThreshold]
+ return ok
+}
+
+// ResetBalanceNotifyThreshold resets all changes to the "balance_notify_threshold" field.
+func (m *UserMutation) ResetBalanceNotifyThreshold() {
+ m.balance_notify_threshold = nil
+ m.addbalance_notify_threshold = nil
+ delete(m.clearedFields, user.FieldBalanceNotifyThreshold)
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (m *UserMutation) SetBalanceNotifyExtraEmails(s string) {
+ m.balance_notify_extra_emails = &s
+}
+
+// BalanceNotifyExtraEmails returns the value of the "balance_notify_extra_emails" field in the mutation.
+func (m *UserMutation) BalanceNotifyExtraEmails() (r string, exists bool) {
+ v := m.balance_notify_extra_emails
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBalanceNotifyExtraEmails returns the old "balance_notify_extra_emails" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldBalanceNotifyExtraEmails(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBalanceNotifyExtraEmails is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBalanceNotifyExtraEmails requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBalanceNotifyExtraEmails: %w", err)
+ }
+ return oldValue.BalanceNotifyExtraEmails, nil
+}
+
+// ResetBalanceNotifyExtraEmails resets all changes to the "balance_notify_extra_emails" field.
+func (m *UserMutation) ResetBalanceNotifyExtraEmails() {
+ m.balance_notify_extra_emails = nil
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (m *UserMutation) SetTotalRecharged(f float64) {
+ m.total_recharged = &f
+ m.addtotal_recharged = nil
+}
+
+// TotalRecharged returns the value of the "total_recharged" field in the mutation.
+func (m *UserMutation) TotalRecharged() (r float64, exists bool) {
+ v := m.total_recharged
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotalRecharged returns the old "total_recharged" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldTotalRecharged(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotalRecharged is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotalRecharged requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotalRecharged: %w", err)
+ }
+ return oldValue.TotalRecharged, nil
+}
+
+// AddTotalRecharged adds f to the "total_recharged" field.
+func (m *UserMutation) AddTotalRecharged(f float64) {
+ if m.addtotal_recharged != nil {
+ *m.addtotal_recharged += f
+ } else {
+ m.addtotal_recharged = &f
+ }
+}
+
+// AddedTotalRecharged returns the value that was added to the "total_recharged" field in this mutation.
+func (m *UserMutation) AddedTotalRecharged() (r float64, exists bool) {
+ v := m.addtotal_recharged
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetTotalRecharged resets all changes to the "total_recharged" field.
+func (m *UserMutation) ResetTotalRecharged() {
+ m.total_recharged = nil
+ m.addtotal_recharged = nil
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -29501,7 +29796,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
- fields := make([]string, 0, 14)
+ fields := make([]string, 0, 19)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -29544,6 +29839,21 @@ func (m *UserMutation) Fields() []string {
if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.balance_notify_enabled != nil {
+ fields = append(fields, user.FieldBalanceNotifyEnabled)
+ }
+ if m.balance_notify_threshold_type != nil {
+ fields = append(fields, user.FieldBalanceNotifyThresholdType)
+ }
+ if m.balance_notify_threshold != nil {
+ fields = append(fields, user.FieldBalanceNotifyThreshold)
+ }
+ if m.balance_notify_extra_emails != nil {
+ fields = append(fields, user.FieldBalanceNotifyExtraEmails)
+ }
+ if m.total_recharged != nil {
+ fields = append(fields, user.FieldTotalRecharged)
+ }
return fields
}
@@ -29580,6 +29890,16 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.TotpEnabled()
case user.FieldTotpEnabledAt:
return m.TotpEnabledAt()
+ case user.FieldBalanceNotifyEnabled:
+ return m.BalanceNotifyEnabled()
+ case user.FieldBalanceNotifyThresholdType:
+ return m.BalanceNotifyThresholdType()
+ case user.FieldBalanceNotifyThreshold:
+ return m.BalanceNotifyThreshold()
+ case user.FieldBalanceNotifyExtraEmails:
+ return m.BalanceNotifyExtraEmails()
+ case user.FieldTotalRecharged:
+ return m.TotalRecharged()
}
return nil, false
}
@@ -29617,6 +29937,16 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx)
+ case user.FieldBalanceNotifyEnabled:
+ return m.OldBalanceNotifyEnabled(ctx)
+ case user.FieldBalanceNotifyThresholdType:
+ return m.OldBalanceNotifyThresholdType(ctx)
+ case user.FieldBalanceNotifyThreshold:
+ return m.OldBalanceNotifyThreshold(ctx)
+ case user.FieldBalanceNotifyExtraEmails:
+ return m.OldBalanceNotifyExtraEmails(ctx)
+ case user.FieldTotalRecharged:
+ return m.OldTotalRecharged(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -29724,6 +30054,41 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetTotpEnabledAt(v)
return nil
+ case user.FieldBalanceNotifyEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBalanceNotifyEnabled(v)
+ return nil
+ case user.FieldBalanceNotifyThresholdType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBalanceNotifyThresholdType(v)
+ return nil
+ case user.FieldBalanceNotifyThreshold:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBalanceNotifyThreshold(v)
+ return nil
+ case user.FieldBalanceNotifyExtraEmails:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBalanceNotifyExtraEmails(v)
+ return nil
+ case user.FieldTotalRecharged:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotalRecharged(v)
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -29738,6 +30103,12 @@ func (m *UserMutation) AddedFields() []string {
if m.addconcurrency != nil {
fields = append(fields, user.FieldConcurrency)
}
+ if m.addbalance_notify_threshold != nil {
+ fields = append(fields, user.FieldBalanceNotifyThreshold)
+ }
+ if m.addtotal_recharged != nil {
+ fields = append(fields, user.FieldTotalRecharged)
+ }
return fields
}
@@ -29750,6 +30121,10 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedBalance()
case user.FieldConcurrency:
return m.AddedConcurrency()
+ case user.FieldBalanceNotifyThreshold:
+ return m.AddedBalanceNotifyThreshold()
+ case user.FieldTotalRecharged:
+ return m.AddedTotalRecharged()
}
return nil, false
}
@@ -29773,6 +30148,20 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
}
m.AddConcurrency(v)
return nil
+ case user.FieldBalanceNotifyThreshold:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddBalanceNotifyThreshold(v)
+ return nil
+ case user.FieldTotalRecharged:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddTotalRecharged(v)
+ return nil
}
return fmt.Errorf("unknown User numeric field %s", name)
}
@@ -29790,6 +30179,9 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldTotpEnabledAt) {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.FieldCleared(user.FieldBalanceNotifyThreshold) {
+ fields = append(fields, user.FieldBalanceNotifyThreshold)
+ }
return fields
}
@@ -29813,6 +30205,9 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldTotpEnabledAt:
m.ClearTotpEnabledAt()
return nil
+ case user.FieldBalanceNotifyThreshold:
+ m.ClearBalanceNotifyThreshold()
+ return nil
}
return fmt.Errorf("unknown User nullable field %s", name)
}
@@ -29863,6 +30258,21 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt()
return nil
+ case user.FieldBalanceNotifyEnabled:
+ m.ResetBalanceNotifyEnabled()
+ return nil
+ case user.FieldBalanceNotifyThresholdType:
+ m.ResetBalanceNotifyThresholdType()
+ return nil
+ case user.FieldBalanceNotifyThreshold:
+ m.ResetBalanceNotifyThreshold()
+ return nil
+ case user.FieldBalanceNotifyExtraEmails:
+ m.ResetBalanceNotifyExtraEmails()
+ return nil
+ case user.FieldTotalRecharged:
+ m.ResetTotalRecharged()
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
diff --git a/backend/ent/paymentproviderinstance.go b/backend/ent/paymentproviderinstance.go
index 087cb13a..4279b86e 100644
--- a/backend/ent/paymentproviderinstance.go
+++ b/backend/ent/paymentproviderinstance.go
@@ -35,6 +35,8 @@ type PaymentProviderInstance struct {
Limits string `json:"limits,omitempty"`
// RefundEnabled holds the value of the "refund_enabled" field.
RefundEnabled bool `json:"refund_enabled,omitempty"`
+ // AllowUserRefund holds the value of the "allow_user_refund" field.
+ AllowUserRefund bool `json:"allow_user_refund,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
@@ -47,7 +49,7 @@ func (*PaymentProviderInstance) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
- case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled:
+ case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled, paymentproviderinstance.FieldAllowUserRefund:
values[i] = new(sql.NullBool)
case paymentproviderinstance.FieldID, paymentproviderinstance.FieldSortOrder:
values[i] = new(sql.NullInt64)
@@ -130,6 +132,12 @@ func (_m *PaymentProviderInstance) assignValues(columns []string, values []any)
} else if value.Valid {
_m.RefundEnabled = value.Bool
}
+ case paymentproviderinstance.FieldAllowUserRefund:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field allow_user_refund", values[i])
+ } else if value.Valid {
+ _m.AllowUserRefund = value.Bool
+ }
case paymentproviderinstance.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
@@ -205,6 +213,9 @@ func (_m *PaymentProviderInstance) String() string {
builder.WriteString("refund_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.RefundEnabled))
builder.WriteString(", ")
+ builder.WriteString("allow_user_refund=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AllowUserRefund))
+ builder.WriteString(", ")
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
diff --git a/backend/ent/paymentproviderinstance/paymentproviderinstance.go b/backend/ent/paymentproviderinstance/paymentproviderinstance.go
index c430fef6..eb1b0c52 100644
--- a/backend/ent/paymentproviderinstance/paymentproviderinstance.go
+++ b/backend/ent/paymentproviderinstance/paymentproviderinstance.go
@@ -31,6 +31,8 @@ const (
FieldLimits = "limits"
// FieldRefundEnabled holds the string denoting the refund_enabled field in the database.
FieldRefundEnabled = "refund_enabled"
+ // FieldAllowUserRefund holds the string denoting the allow_user_refund field in the database.
+ FieldAllowUserRefund = "allow_user_refund"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
@@ -51,6 +53,7 @@ var Columns = []string{
FieldSortOrder,
FieldLimits,
FieldRefundEnabled,
+ FieldAllowUserRefund,
FieldCreatedAt,
FieldUpdatedAt,
}
@@ -88,6 +91,8 @@ var (
DefaultLimits string
// DefaultRefundEnabled holds the default value on creation for the "refund_enabled" field.
DefaultRefundEnabled bool
+ // DefaultAllowUserRefund holds the default value on creation for the "allow_user_refund" field.
+ DefaultAllowUserRefund bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
@@ -149,6 +154,11 @@ func ByRefundEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRefundEnabled, opts...).ToFunc()
}
+// ByAllowUserRefund orders the results by the allow_user_refund field.
+func ByAllowUserRefund(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAllowUserRefund, opts...).ToFunc()
+}
+
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
diff --git a/backend/ent/paymentproviderinstance/where.go b/backend/ent/paymentproviderinstance/where.go
index 7b99517f..40e5a1f6 100644
--- a/backend/ent/paymentproviderinstance/where.go
+++ b/backend/ent/paymentproviderinstance/where.go
@@ -99,6 +99,11 @@ func RefundEnabled(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v))
}
+// AllowUserRefund applies equality check predicate on the "allow_user_refund" field. It's identical to AllowUserRefundEQ.
+func AllowUserRefund(v bool) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v))
+}
+
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))
@@ -559,6 +564,16 @@ func RefundEnabledNEQ(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldRefundEnabled, v))
}
+// AllowUserRefundEQ applies the EQ predicate on the "allow_user_refund" field.
+func AllowUserRefundEQ(v bool) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v))
+}
+
+// AllowUserRefundNEQ applies the NEQ predicate on the "allow_user_refund" field.
+func AllowUserRefundNEQ(v bool) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldAllowUserRefund, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))
diff --git a/backend/ent/paymentproviderinstance_create.go b/backend/ent/paymentproviderinstance_create.go
index 20b16ddd..d1b14617 100644
--- a/backend/ent/paymentproviderinstance_create.go
+++ b/backend/ent/paymentproviderinstance_create.go
@@ -132,6 +132,20 @@ func (_c *PaymentProviderInstanceCreate) SetNillableRefundEnabled(v *bool) *Paym
return _c
}
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (_c *PaymentProviderInstanceCreate) SetAllowUserRefund(v bool) *PaymentProviderInstanceCreate {
+ _c.mutation.SetAllowUserRefund(v)
+ return _c
+}
+
+// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
+func (_c *PaymentProviderInstanceCreate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceCreate {
+ if v != nil {
+ _c.SetAllowUserRefund(*v)
+ }
+ return _c
+}
+
// SetCreatedAt sets the "created_at" field.
func (_c *PaymentProviderInstanceCreate) SetCreatedAt(v time.Time) *PaymentProviderInstanceCreate {
_c.mutation.SetCreatedAt(v)
@@ -223,6 +237,10 @@ func (_c *PaymentProviderInstanceCreate) defaults() {
v := paymentproviderinstance.DefaultRefundEnabled
_c.mutation.SetRefundEnabled(v)
}
+ if _, ok := _c.mutation.AllowUserRefund(); !ok {
+ v := paymentproviderinstance.DefaultAllowUserRefund
+ _c.mutation.SetAllowUserRefund(v)
+ }
if _, ok := _c.mutation.CreatedAt(); !ok {
v := paymentproviderinstance.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
@@ -282,6 +300,9 @@ func (_c *PaymentProviderInstanceCreate) check() error {
if _, ok := _c.mutation.RefundEnabled(); !ok {
return &ValidationError{Name: "refund_enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.refund_enabled"`)}
}
+ if _, ok := _c.mutation.AllowUserRefund(); !ok {
+ return &ValidationError{Name: "allow_user_refund", err: errors.New(`ent: missing required field "PaymentProviderInstance.allow_user_refund"`)}
+ }
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.created_at"`)}
}
@@ -351,6 +372,10 @@ func (_c *PaymentProviderInstanceCreate) createSpec() (*PaymentProviderInstance,
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
_node.RefundEnabled = value
}
+ if value, ok := _c.mutation.AllowUserRefund(); ok {
+ _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
+ _node.AllowUserRefund = value
+ }
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(paymentproviderinstance.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
@@ -525,6 +550,18 @@ func (u *PaymentProviderInstanceUpsert) UpdateRefundEnabled() *PaymentProviderIn
return u
}
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (u *PaymentProviderInstanceUpsert) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsert {
+ u.Set(paymentproviderinstance.FieldAllowUserRefund, v)
+ return u
+}
+
+// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsert) UpdateAllowUserRefund() *PaymentProviderInstanceUpsert {
+ u.SetExcluded(paymentproviderinstance.FieldAllowUserRefund)
+ return u
+}
+
// SetUpdatedAt sets the "updated_at" field.
func (u *PaymentProviderInstanceUpsert) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsert {
u.Set(paymentproviderinstance.FieldUpdatedAt, v)
@@ -715,6 +752,20 @@ func (u *PaymentProviderInstanceUpsertOne) UpdateRefundEnabled() *PaymentProvide
})
}
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (u *PaymentProviderInstanceUpsertOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.SetAllowUserRefund(v)
+ })
+}
+
+// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertOne) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.UpdateAllowUserRefund()
+ })
+}
+
// SetUpdatedAt sets the "updated_at" field.
func (u *PaymentProviderInstanceUpsertOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertOne {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
@@ -1073,6 +1124,20 @@ func (u *PaymentProviderInstanceUpsertBulk) UpdateRefundEnabled() *PaymentProvid
})
}
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (u *PaymentProviderInstanceUpsertBulk) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertBulk {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.SetAllowUserRefund(v)
+ })
+}
+
+// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertBulk) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertBulk {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.UpdateAllowUserRefund()
+ })
+}
+
// SetUpdatedAt sets the "updated_at" field.
func (u *PaymentProviderInstanceUpsertBulk) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertBulk {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
diff --git a/backend/ent/paymentproviderinstance_update.go b/backend/ent/paymentproviderinstance_update.go
index 06dba527..6bb3a82d 100644
--- a/backend/ent/paymentproviderinstance_update.go
+++ b/backend/ent/paymentproviderinstance_update.go
@@ -161,6 +161,20 @@ func (_u *PaymentProviderInstanceUpdate) SetNillableRefundEnabled(v *bool) *Paym
return _u
}
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (_u *PaymentProviderInstanceUpdate) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdate {
+ _u.mutation.SetAllowUserRefund(v)
+ return _u
+}
+
+// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdate {
+ if v != nil {
+ _u.SetAllowUserRefund(*v)
+ }
+ return _u
+}
+
// SetUpdatedAt sets the "updated_at" field.
func (_u *PaymentProviderInstanceUpdate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdate {
_u.mutation.SetUpdatedAt(v)
@@ -275,6 +289,9 @@ func (_u *PaymentProviderInstanceUpdate) sqlSave(ctx context.Context) (_node int
if value, ok := _u.mutation.RefundEnabled(); ok {
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
}
+ if value, ok := _u.mutation.AllowUserRefund(); ok {
+ _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
+ }
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
}
@@ -431,6 +448,20 @@ func (_u *PaymentProviderInstanceUpdateOne) SetNillableRefundEnabled(v *bool) *P
return _u
}
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (_u *PaymentProviderInstanceUpdateOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.SetAllowUserRefund(v)
+ return _u
+}
+
+// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdateOne) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdateOne {
+ if v != nil {
+ _u.SetAllowUserRefund(*v)
+ }
+ return _u
+}
+
// SetUpdatedAt sets the "updated_at" field.
func (_u *PaymentProviderInstanceUpdateOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdateOne {
_u.mutation.SetUpdatedAt(v)
@@ -575,6 +606,9 @@ func (_u *PaymentProviderInstanceUpdateOne) sqlSave(ctx context.Context) (_node
if value, ok := _u.mutation.RefundEnabled(); ok {
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
}
+ if value, ok := _u.mutation.AllowUserRefund(); ok {
+ _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
+ }
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
}
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index 821b7d66..fbdd08c7 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -668,12 +668,16 @@ func init() {
paymentproviderinstanceDescRefundEnabled := paymentproviderinstanceFields[8].Descriptor()
// paymentproviderinstance.DefaultRefundEnabled holds the default value on creation for the refund_enabled field.
paymentproviderinstance.DefaultRefundEnabled = paymentproviderinstanceDescRefundEnabled.Default.(bool)
+ // paymentproviderinstanceDescAllowUserRefund is the schema descriptor for allow_user_refund field.
+ paymentproviderinstanceDescAllowUserRefund := paymentproviderinstanceFields[9].Descriptor()
+ // paymentproviderinstance.DefaultAllowUserRefund holds the default value on creation for the allow_user_refund field.
+ paymentproviderinstance.DefaultAllowUserRefund = paymentproviderinstanceDescAllowUserRefund.Default.(bool)
// paymentproviderinstanceDescCreatedAt is the schema descriptor for created_at field.
- paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[9].Descriptor()
+ paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[10].Descriptor()
// paymentproviderinstance.DefaultCreatedAt holds the default value on creation for the created_at field.
paymentproviderinstance.DefaultCreatedAt = paymentproviderinstanceDescCreatedAt.Default.(func() time.Time)
// paymentproviderinstanceDescUpdatedAt is the schema descriptor for updated_at field.
- paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[10].Descriptor()
+ paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[11].Descriptor()
// paymentproviderinstance.DefaultUpdatedAt holds the default value on creation for the updated_at field.
paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time)
// paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
@@ -1293,6 +1297,22 @@ func init() {
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
+ // userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field.
+ userDescBalanceNotifyEnabled := userFields[11].Descriptor()
+ // user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field.
+ user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool)
+ // userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field.
+ userDescBalanceNotifyThresholdType := userFields[12].Descriptor()
+ // user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field.
+ user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string)
+ // userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field.
+ userDescBalanceNotifyExtraEmails := userFields[14].Descriptor()
+ // user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field.
+ user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string)
+ // userDescTotalRecharged is the schema descriptor for total_recharged field.
+ userDescTotalRecharged := userFields[15].Descriptor()
+ // user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
+ user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
diff --git a/backend/ent/schema/payment_provider_instance.go b/backend/ent/schema/payment_provider_instance.go
index 08ab7d31..e4c0b72c 100644
--- a/backend/ent/schema/payment_provider_instance.go
+++ b/backend/ent/schema/payment_provider_instance.go
@@ -53,6 +53,8 @@ func (PaymentProviderInstance) Fields() []ent.Field {
Default(""),
field.Bool("refund_enabled").
Default(false),
+ field.Bool("allow_user_refund").
+ Default(false),
field.Time("created_at").
Immutable().
Default(time.Now).
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index af143d38..ef52e985 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -72,6 +72,22 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
+
+ // 余额不足通知
+ field.Bool("balance_notify_enabled").
+ Default(true),
+ field.String("balance_notify_threshold_type").
+ Default("fixed"), // "fixed" | "percentage"
+ field.Float("balance_notify_threshold").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
+ Optional().
+ Nillable(),
+ field.String("balance_notify_extra_emails").
+ SchemaType(map[string]string{dialect.Postgres: "text"}).
+ Default("[]"),
+ field.Float("total_recharged").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
+ Default(0),
}
}
diff --git a/backend/ent/user.go b/backend/ent/user.go
index a0eef2ba..9fa91f74 100644
--- a/backend/ent/user.go
+++ b/backend/ent/user.go
@@ -45,6 +45,16 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
+ // BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
+ BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
+ // BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field.
+ BalanceNotifyThresholdType string `json:"balance_notify_threshold_type,omitempty"`
+ // BalanceNotifyThreshold holds the value of the "balance_notify_threshold" field.
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
+ // BalanceNotifyExtraEmails holds the value of the "balance_notify_extra_emails" field.
+ BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"`
+ // TotalRecharged holds the value of the "total_recharged" field.
+ TotalRecharged float64 `json:"total_recharged,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -184,13 +194,13 @@ func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
- case user.FieldTotpEnabled:
+ case user.FieldTotpEnabled, user.FieldBalanceNotifyEnabled:
values[i] = new(sql.NullBool)
- case user.FieldBalance:
+ case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged:
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
- case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
+ case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
values[i] = new(sql.NullTime)
@@ -302,6 +312,37 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
+ case user.FieldBalanceNotifyEnabled:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i])
+ } else if value.Valid {
+ _m.BalanceNotifyEnabled = value.Bool
+ }
+ case user.FieldBalanceNotifyThresholdType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field balance_notify_threshold_type", values[i])
+ } else if value.Valid {
+ _m.BalanceNotifyThresholdType = value.String
+ }
+ case user.FieldBalanceNotifyThreshold:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field balance_notify_threshold", values[i])
+ } else if value.Valid {
+ _m.BalanceNotifyThreshold = new(float64)
+ *_m.BalanceNotifyThreshold = value.Float64
+ }
+ case user.FieldBalanceNotifyExtraEmails:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field balance_notify_extra_emails", values[i])
+ } else if value.Valid {
+ _m.BalanceNotifyExtraEmails = value.String
+ }
+ case user.FieldTotalRecharged:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field total_recharged", values[i])
+ } else if value.Valid {
+ _m.TotalRecharged = value.Float64
+ }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -440,6 +481,23 @@ func (_m *User) String() string {
builder.WriteString("totp_enabled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
+ builder.WriteString(", ")
+ builder.WriteString("balance_notify_enabled=")
+ builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
+ builder.WriteString(", ")
+ builder.WriteString("balance_notify_threshold_type=")
+ builder.WriteString(_m.BalanceNotifyThresholdType)
+ builder.WriteString(", ")
+ if v := _m.BalanceNotifyThreshold; v != nil {
+ builder.WriteString("balance_notify_threshold=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("balance_notify_extra_emails=")
+ builder.WriteString(_m.BalanceNotifyExtraEmails)
+ builder.WriteString(", ")
+ builder.WriteString("total_recharged=")
+ builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged))
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go
index 338518a8..d88a3a38 100644
--- a/backend/ent/user/user.go
+++ b/backend/ent/user/user.go
@@ -43,6 +43,16 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
+ // FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
+ FieldBalanceNotifyEnabled = "balance_notify_enabled"
+ // FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database.
+ FieldBalanceNotifyThresholdType = "balance_notify_threshold_type"
+ // FieldBalanceNotifyThreshold holds the string denoting the balance_notify_threshold field in the database.
+ FieldBalanceNotifyThreshold = "balance_notify_threshold"
+ // FieldBalanceNotifyExtraEmails holds the string denoting the balance_notify_extra_emails field in the database.
+ FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails"
+ // FieldTotalRecharged holds the string denoting the total_recharged field in the database.
+ FieldTotalRecharged = "total_recharged"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -161,6 +171,11 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
+ FieldBalanceNotifyEnabled,
+ FieldBalanceNotifyThresholdType,
+ FieldBalanceNotifyThreshold,
+ FieldBalanceNotifyExtraEmails,
+ FieldTotalRecharged,
}
var (
@@ -217,6 +232,14 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
+ // DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
+ DefaultBalanceNotifyEnabled bool
+ // DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field.
+ DefaultBalanceNotifyThresholdType string
+ // DefaultBalanceNotifyExtraEmails holds the default value on creation for the "balance_notify_extra_emails" field.
+ DefaultBalanceNotifyExtraEmails string
+ // DefaultTotalRecharged holds the default value on creation for the "total_recharged" field.
+ DefaultTotalRecharged float64
)
// OrderOption defines the ordering options for the User queries.
@@ -297,6 +320,31 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
+// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field.
+func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
+}
+
+// ByBalanceNotifyThresholdType orders the results by the balance_notify_threshold_type field.
+func ByBalanceNotifyThresholdType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBalanceNotifyThresholdType, opts...).ToFunc()
+}
+
+// ByBalanceNotifyThreshold orders the results by the balance_notify_threshold field.
+func ByBalanceNotifyThreshold(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBalanceNotifyThreshold, opts...).ToFunc()
+}
+
+// ByBalanceNotifyExtraEmails orders the results by the balance_notify_extra_emails field.
+func ByBalanceNotifyExtraEmails(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBalanceNotifyExtraEmails, opts...).ToFunc()
+}
+
+// ByTotalRecharged orders the results by the total_recharged field.
+func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc()
+}
+
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go
index b1d1000f..2788aa7a 100644
--- a/backend/ent/user/where.go
+++ b/backend/ent/user/where.go
@@ -125,6 +125,31 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
+// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ.
+func BalanceNotifyEnabled(v bool) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
+}
+
+// BalanceNotifyThresholdType applies equality check predicate on the "balance_notify_threshold_type" field. It's identical to BalanceNotifyThresholdTypeEQ.
+func BalanceNotifyThresholdType(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThreshold applies equality check predicate on the "balance_notify_threshold" field. It's identical to BalanceNotifyThresholdEQ.
+func BalanceNotifyThreshold(v float64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyExtraEmails applies equality check predicate on the "balance_notify_extra_emails" field. It's identical to BalanceNotifyExtraEmailsEQ.
+func BalanceNotifyExtraEmails(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v))
+}
+
+// TotalRecharged applies equality check predicate on the "total_recharged" field. It's identical to TotalRechargedEQ.
+func TotalRecharged(v float64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -860,6 +885,236 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
+// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field.
+func BalanceNotifyEnabledEQ(v bool) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
+}
+
+// BalanceNotifyEnabledNEQ applies the NEQ predicate on the "balance_notify_enabled" field.
+func BalanceNotifyEnabledNEQ(v bool) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldBalanceNotifyEnabled, v))
+}
+
+// BalanceNotifyThresholdTypeEQ applies the EQ predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeEQ(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeNEQ applies the NEQ predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeNEQ(v string) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeIn applies the In predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldIn(FieldBalanceNotifyThresholdType, vs...))
+}
+
+// BalanceNotifyThresholdTypeNotIn applies the NotIn predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeNotIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThresholdType, vs...))
+}
+
+// BalanceNotifyThresholdTypeGT applies the GT predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeGT(v string) predicate.User {
+ return predicate.User(sql.FieldGT(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeGTE applies the GTE predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeGTE(v string) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeLT applies the LT predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeLT(v string) predicate.User {
+ return predicate.User(sql.FieldLT(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeLTE applies the LTE predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeLTE(v string) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeContains applies the Contains predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeContains(v string) predicate.User {
+ return predicate.User(sql.FieldContains(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeHasPrefix applies the HasPrefix predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeHasPrefix(v string) predicate.User {
+ return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeHasSuffix applies the HasSuffix predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeHasSuffix(v string) predicate.User {
+ return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeEqualFold applies the EqualFold predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeEqualFold(v string) predicate.User {
+ return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeContainsFold applies the ContainsFold predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeContainsFold(v string) predicate.User {
+ return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdEQ applies the EQ predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdEQ(v float64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdNEQ applies the NEQ predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdNEQ(v float64) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdIn applies the In predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdIn(vs ...float64) predicate.User {
+ return predicate.User(sql.FieldIn(FieldBalanceNotifyThreshold, vs...))
+}
+
+// BalanceNotifyThresholdNotIn applies the NotIn predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdNotIn(vs ...float64) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThreshold, vs...))
+}
+
+// BalanceNotifyThresholdGT applies the GT predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdGT(v float64) predicate.User {
+ return predicate.User(sql.FieldGT(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdGTE applies the GTE predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdGTE(v float64) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdLT applies the LT predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdLT(v float64) predicate.User {
+ return predicate.User(sql.FieldLT(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdLTE applies the LTE predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdLTE(v float64) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdIsNil applies the IsNil predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldBalanceNotifyThreshold))
+}
+
+// BalanceNotifyThresholdNotNil applies the NotNil predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldBalanceNotifyThreshold))
+}
+
+// BalanceNotifyExtraEmailsEQ applies the EQ predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsEQ(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsNEQ applies the NEQ predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsNEQ(v string) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsIn applies the In predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldIn(FieldBalanceNotifyExtraEmails, vs...))
+}
+
+// BalanceNotifyExtraEmailsNotIn applies the NotIn predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsNotIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldBalanceNotifyExtraEmails, vs...))
+}
+
+// BalanceNotifyExtraEmailsGT applies the GT predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsGT(v string) predicate.User {
+ return predicate.User(sql.FieldGT(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsGTE applies the GTE predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsGTE(v string) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsLT applies the LT predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsLT(v string) predicate.User {
+ return predicate.User(sql.FieldLT(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsLTE applies the LTE predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsLTE(v string) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsContains applies the Contains predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsContains(v string) predicate.User {
+ return predicate.User(sql.FieldContains(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsHasPrefix applies the HasPrefix predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsHasPrefix(v string) predicate.User {
+ return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsHasSuffix applies the HasSuffix predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsHasSuffix(v string) predicate.User {
+ return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsEqualFold applies the EqualFold predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsEqualFold(v string) predicate.User {
+ return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsContainsFold applies the ContainsFold predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsContainsFold(v string) predicate.User {
+ return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyExtraEmails, v))
+}
+
+// TotalRechargedEQ applies the EQ predicate on the "total_recharged" field.
+func TotalRechargedEQ(v float64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
+}
+
+// TotalRechargedNEQ applies the NEQ predicate on the "total_recharged" field.
+func TotalRechargedNEQ(v float64) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldTotalRecharged, v))
+}
+
+// TotalRechargedIn applies the In predicate on the "total_recharged" field.
+func TotalRechargedIn(vs ...float64) predicate.User {
+ return predicate.User(sql.FieldIn(FieldTotalRecharged, vs...))
+}
+
+// TotalRechargedNotIn applies the NotIn predicate on the "total_recharged" field.
+func TotalRechargedNotIn(vs ...float64) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldTotalRecharged, vs...))
+}
+
+// TotalRechargedGT applies the GT predicate on the "total_recharged" field.
+func TotalRechargedGT(v float64) predicate.User {
+ return predicate.User(sql.FieldGT(FieldTotalRecharged, v))
+}
+
+// TotalRechargedGTE applies the GTE predicate on the "total_recharged" field.
+func TotalRechargedGTE(v float64) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldTotalRecharged, v))
+}
+
+// TotalRechargedLT applies the LT predicate on the "total_recharged" field.
+func TotalRechargedLT(v float64) predicate.User {
+ return predicate.User(sql.FieldLT(FieldTotalRecharged, v))
+}
+
+// TotalRechargedLTE applies the LTE predicate on the "total_recharged" field.
+func TotalRechargedLTE(v float64) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldTotalRecharged, v))
+}
+
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {
diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go
index 7f1c5df1..fbc64f9c 100644
--- a/backend/ent/user_create.go
+++ b/backend/ent/user_create.go
@@ -211,6 +211,76 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate {
+ _c.mutation.SetBalanceNotifyEnabled(v)
+ return _c
+}
+
+// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
+func (_c *UserCreate) SetNillableBalanceNotifyEnabled(v *bool) *UserCreate {
+ if v != nil {
+ _c.SetBalanceNotifyEnabled(*v)
+ }
+ return _c
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (_c *UserCreate) SetBalanceNotifyThresholdType(v string) *UserCreate {
+ _c.mutation.SetBalanceNotifyThresholdType(v)
+ return _c
+}
+
+// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
+func (_c *UserCreate) SetNillableBalanceNotifyThresholdType(v *string) *UserCreate {
+ if v != nil {
+ _c.SetBalanceNotifyThresholdType(*v)
+ }
+ return _c
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (_c *UserCreate) SetBalanceNotifyThreshold(v float64) *UserCreate {
+ _c.mutation.SetBalanceNotifyThreshold(v)
+ return _c
+}
+
+// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
+func (_c *UserCreate) SetNillableBalanceNotifyThreshold(v *float64) *UserCreate {
+ if v != nil {
+ _c.SetBalanceNotifyThreshold(*v)
+ }
+ return _c
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (_c *UserCreate) SetBalanceNotifyExtraEmails(v string) *UserCreate {
+ _c.mutation.SetBalanceNotifyExtraEmails(v)
+ return _c
+}
+
+// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
+func (_c *UserCreate) SetNillableBalanceNotifyExtraEmails(v *string) *UserCreate {
+ if v != nil {
+ _c.SetBalanceNotifyExtraEmails(*v)
+ }
+ return _c
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (_c *UserCreate) SetTotalRecharged(v float64) *UserCreate {
+ _c.mutation.SetTotalRecharged(v)
+ return _c
+}
+
+// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
+func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate {
+ if v != nil {
+ _c.SetTotalRecharged(*v)
+ }
+ return _c
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -440,6 +510,22 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
+ if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
+ v := user.DefaultBalanceNotifyEnabled
+ _c.mutation.SetBalanceNotifyEnabled(v)
+ }
+ if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok {
+ v := user.DefaultBalanceNotifyThresholdType
+ _c.mutation.SetBalanceNotifyThresholdType(v)
+ }
+ if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
+ v := user.DefaultBalanceNotifyExtraEmails
+ _c.mutation.SetBalanceNotifyExtraEmails(v)
+ }
+ if _, ok := _c.mutation.TotalRecharged(); !ok {
+ v := user.DefaultTotalRecharged
+ _c.mutation.SetTotalRecharged(v)
+ }
return nil
}
@@ -503,6 +589,18 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
+ if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
+ return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
+ }
+ if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok {
+ return &ValidationError{Name: "balance_notify_threshold_type", err: errors.New(`ent: missing required field "User.balance_notify_threshold_type"`)}
+ }
+ if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
+ return &ValidationError{Name: "balance_notify_extra_emails", err: errors.New(`ent: missing required field "User.balance_notify_extra_emails"`)}
+ }
+ if _, ok := _c.mutation.TotalRecharged(); !ok {
+ return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)}
+ }
return nil
}
@@ -586,6 +684,26 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
+ if value, ok := _c.mutation.BalanceNotifyEnabled(); ok {
+ _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
+ _node.BalanceNotifyEnabled = value
+ }
+ if value, ok := _c.mutation.BalanceNotifyThresholdType(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
+ _node.BalanceNotifyThresholdType = value
+ }
+ if value, ok := _c.mutation.BalanceNotifyThreshold(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ _node.BalanceNotifyThreshold = &value
+ }
+ if value, ok := _c.mutation.BalanceNotifyExtraEmails(); ok {
+ _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
+ _node.BalanceNotifyExtraEmails = value
+ }
+ if value, ok := _c.mutation.TotalRecharged(); ok {
+ _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ _node.TotalRecharged = value
+ }
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -988,6 +1106,84 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert {
+ u.Set(user.FieldBalanceNotifyEnabled, v)
+ return u
+}
+
+// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
+func (u *UserUpsert) UpdateBalanceNotifyEnabled() *UserUpsert {
+ u.SetExcluded(user.FieldBalanceNotifyEnabled)
+ return u
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (u *UserUpsert) SetBalanceNotifyThresholdType(v string) *UserUpsert {
+ u.Set(user.FieldBalanceNotifyThresholdType, v)
+ return u
+}
+
+// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
+func (u *UserUpsert) UpdateBalanceNotifyThresholdType() *UserUpsert {
+ u.SetExcluded(user.FieldBalanceNotifyThresholdType)
+ return u
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (u *UserUpsert) SetBalanceNotifyThreshold(v float64) *UserUpsert {
+ u.Set(user.FieldBalanceNotifyThreshold, v)
+ return u
+}
+
+// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
+func (u *UserUpsert) UpdateBalanceNotifyThreshold() *UserUpsert {
+ u.SetExcluded(user.FieldBalanceNotifyThreshold)
+ return u
+}
+
+// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
+func (u *UserUpsert) AddBalanceNotifyThreshold(v float64) *UserUpsert {
+ u.Add(user.FieldBalanceNotifyThreshold, v)
+ return u
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (u *UserUpsert) ClearBalanceNotifyThreshold() *UserUpsert {
+ u.SetNull(user.FieldBalanceNotifyThreshold)
+ return u
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (u *UserUpsert) SetBalanceNotifyExtraEmails(v string) *UserUpsert {
+ u.Set(user.FieldBalanceNotifyExtraEmails, v)
+ return u
+}
+
+// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
+func (u *UserUpsert) UpdateBalanceNotifyExtraEmails() *UserUpsert {
+ u.SetExcluded(user.FieldBalanceNotifyExtraEmails)
+ return u
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (u *UserUpsert) SetTotalRecharged(v float64) *UserUpsert {
+ u.Set(user.FieldTotalRecharged, v)
+ return u
+}
+
+// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
+func (u *UserUpsert) UpdateTotalRecharged() *UserUpsert {
+ u.SetExcluded(user.FieldTotalRecharged)
+ return u
+}
+
+// AddTotalRecharged adds v to the "total_recharged" field.
+func (u *UserUpsert) AddTotalRecharged(v float64) *UserUpsert {
+ u.Add(user.FieldTotalRecharged, v)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1250,6 +1446,97 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyEnabled(v)
+ })
+}
+
+// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateBalanceNotifyEnabled() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyEnabled()
+ })
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (u *UserUpsertOne) SetBalanceNotifyThresholdType(v string) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyThresholdType(v)
+ })
+}
+
+// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateBalanceNotifyThresholdType() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyThresholdType()
+ })
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (u *UserUpsertOne) SetBalanceNotifyThreshold(v float64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyThreshold(v)
+ })
+}
+
+// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
+func (u *UserUpsertOne) AddBalanceNotifyThreshold(v float64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.AddBalanceNotifyThreshold(v)
+ })
+}
+
+// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateBalanceNotifyThreshold() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyThreshold()
+ })
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (u *UserUpsertOne) ClearBalanceNotifyThreshold() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearBalanceNotifyThreshold()
+ })
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (u *UserUpsertOne) SetBalanceNotifyExtraEmails(v string) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyExtraEmails(v)
+ })
+}
+
+// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateBalanceNotifyExtraEmails() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyExtraEmails()
+ })
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (u *UserUpsertOne) SetTotalRecharged(v float64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotalRecharged(v)
+ })
+}
+
+// AddTotalRecharged adds v to the "total_recharged" field.
+func (u *UserUpsertOne) AddTotalRecharged(v float64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.AddTotalRecharged(v)
+ })
+}
+
+// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotalRecharged()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1678,6 +1965,97 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyEnabled(v)
+ })
+}
+
+// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateBalanceNotifyEnabled() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyEnabled()
+ })
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (u *UserUpsertBulk) SetBalanceNotifyThresholdType(v string) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyThresholdType(v)
+ })
+}
+
+// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateBalanceNotifyThresholdType() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyThresholdType()
+ })
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (u *UserUpsertBulk) SetBalanceNotifyThreshold(v float64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyThreshold(v)
+ })
+}
+
+// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
+func (u *UserUpsertBulk) AddBalanceNotifyThreshold(v float64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.AddBalanceNotifyThreshold(v)
+ })
+}
+
+// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateBalanceNotifyThreshold() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyThreshold()
+ })
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (u *UserUpsertBulk) ClearBalanceNotifyThreshold() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearBalanceNotifyThreshold()
+ })
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (u *UserUpsertBulk) SetBalanceNotifyExtraEmails(v string) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyExtraEmails(v)
+ })
+}
+
+// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateBalanceNotifyExtraEmails() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyExtraEmails()
+ })
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (u *UserUpsertBulk) SetTotalRecharged(v float64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotalRecharged(v)
+ })
+}
+
+// AddTotalRecharged adds v to the "total_recharged" field.
+func (u *UserUpsertBulk) AddTotalRecharged(v float64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.AddTotalRecharged(v)
+ })
+}
+
+// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotalRecharged()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go
index 8107c980..6b355247 100644
--- a/backend/ent/user_update.go
+++ b/backend/ent/user_update.go
@@ -243,6 +243,96 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate {
+ _u.mutation.SetBalanceNotifyEnabled(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdate {
+ if v != nil {
+ _u.SetBalanceNotifyEnabled(*v)
+ }
+ return _u
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (_u *UserUpdate) SetBalanceNotifyThresholdType(v string) *UserUpdate {
+ _u.mutation.SetBalanceNotifyThresholdType(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdate {
+ if v != nil {
+ _u.SetBalanceNotifyThresholdType(*v)
+ }
+ return _u
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (_u *UserUpdate) SetBalanceNotifyThreshold(v float64) *UserUpdate {
+ _u.mutation.ResetBalanceNotifyThreshold()
+ _u.mutation.SetBalanceNotifyThreshold(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdate {
+ if v != nil {
+ _u.SetBalanceNotifyThreshold(*v)
+ }
+ return _u
+}
+
+// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field.
+func (_u *UserUpdate) AddBalanceNotifyThreshold(v float64) *UserUpdate {
+ _u.mutation.AddBalanceNotifyThreshold(v)
+ return _u
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (_u *UserUpdate) ClearBalanceNotifyThreshold() *UserUpdate {
+ _u.mutation.ClearBalanceNotifyThreshold()
+ return _u
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (_u *UserUpdate) SetBalanceNotifyExtraEmails(v string) *UserUpdate {
+ _u.mutation.SetBalanceNotifyExtraEmails(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdate {
+ if v != nil {
+ _u.SetBalanceNotifyExtraEmails(*v)
+ }
+ return _u
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (_u *UserUpdate) SetTotalRecharged(v float64) *UserUpdate {
+ _u.mutation.ResetTotalRecharged()
+ _u.mutation.SetTotalRecharged(v)
+ return _u
+}
+
+// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableTotalRecharged(v *float64) *UserUpdate {
+ if v != nil {
+ _u.SetTotalRecharged(*v)
+ }
+ return _u
+}
+
+// AddTotalRecharged adds value to the "total_recharged" field.
+func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate {
+ _u.mutation.AddTotalRecharged(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -746,6 +836,30 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
+ _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok {
+ _spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ }
+ if _u.mutation.BalanceNotifyThresholdCleared() {
+ _spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64)
+ }
+ if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
+ _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.TotalRecharged(); ok {
+ _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedTotalRecharged(); ok {
+ _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1434,6 +1548,96 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne {
+ _u.mutation.SetBalanceNotifyEnabled(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdateOne {
+ if v != nil {
+ _u.SetBalanceNotifyEnabled(*v)
+ }
+ return _u
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (_u *UserUpdateOne) SetBalanceNotifyThresholdType(v string) *UserUpdateOne {
+ _u.mutation.SetBalanceNotifyThresholdType(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdateOne {
+ if v != nil {
+ _u.SetBalanceNotifyThresholdType(*v)
+ }
+ return _u
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (_u *UserUpdateOne) SetBalanceNotifyThreshold(v float64) *UserUpdateOne {
+ _u.mutation.ResetBalanceNotifyThreshold()
+ _u.mutation.SetBalanceNotifyThreshold(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdateOne {
+ if v != nil {
+ _u.SetBalanceNotifyThreshold(*v)
+ }
+ return _u
+}
+
+// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field.
+func (_u *UserUpdateOne) AddBalanceNotifyThreshold(v float64) *UserUpdateOne {
+ _u.mutation.AddBalanceNotifyThreshold(v)
+ return _u
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (_u *UserUpdateOne) ClearBalanceNotifyThreshold() *UserUpdateOne {
+ _u.mutation.ClearBalanceNotifyThreshold()
+ return _u
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (_u *UserUpdateOne) SetBalanceNotifyExtraEmails(v string) *UserUpdateOne {
+ _u.mutation.SetBalanceNotifyExtraEmails(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdateOne {
+ if v != nil {
+ _u.SetBalanceNotifyExtraEmails(*v)
+ }
+ return _u
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (_u *UserUpdateOne) SetTotalRecharged(v float64) *UserUpdateOne {
+ _u.mutation.ResetTotalRecharged()
+ _u.mutation.SetTotalRecharged(v)
+ return _u
+}
+
+// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableTotalRecharged(v *float64) *UserUpdateOne {
+ if v != nil {
+ _u.SetTotalRecharged(*v)
+ }
+ return _u
+}
+
+// AddTotalRecharged adds value to the "total_recharged" field.
+func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne {
+ _u.mutation.AddTotalRecharged(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1967,6 +2171,30 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
+ _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok {
+ _spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ }
+ if _u.mutation.BalanceNotifyThresholdCleared() {
+ _spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64)
+ }
+ if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
+ _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.TotalRecharged(); ok {
+ _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedTotalRecharged(); ok {
+ _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
diff --git a/backend/go.sum b/backend/go.sum
index e4496f2c..9312af63 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -183,6 +183,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
+github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
+github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -218,6 +220,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
+github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -251,6 +255,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
+github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
+github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -280,6 +286,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
+github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@@ -312,6 +320,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
+github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
+github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index bc4e5e46..dd9a4e58 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -28,7 +28,7 @@ const (
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
-const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
+const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com https://*.stripe.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// UMQ(用户消息队列)模式常量
const (
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index fe181a2f..cf58316c 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -233,12 +233,13 @@ func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
configPath := filepath.Join(tempDir, "config.yaml")
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
- require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+templatePath+"\"\n"), 0o644))
+ yamlSafePath := filepath.ToSlash(templatePath)
+ require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+yamlSafePath+"\"\n"), 0o644))
t.Setenv("DATA_DIR", tempDir)
cfg, err := Load()
require.NoError(t, err)
- require.Equal(t, templatePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile)
+ require.Equal(t, yamlSafePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile)
require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate)
}
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 9e985a79..9883d007 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -1412,6 +1412,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
+ "details": gin.H{
+ "group_id": mixedErr.GroupID,
+ "group_name": mixedErr.GroupName,
+ "current_platform": mixedErr.CurrentPlatform,
+ "other_platform": mixedErr.OtherPlatform,
+ },
})
return
}
diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go
index c92b35bb..9151d018 100644
--- a/backend/internal/handler/admin/channel_handler.go
+++ b/backend/internal/handler/admin/channel_handler.go
@@ -1,6 +1,7 @@
package admin
import (
+ "fmt"
"strconv"
"strings"
@@ -26,24 +27,32 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s
// --- Request / Response types ---
type createChannelRequest struct {
- Name string `json:"name" binding:"required,max=100"`
- Description string `json:"description"`
- GroupIDs []int64 `json:"group_ids"`
- ModelPricing []channelModelPricingRequest `json:"model_pricing"`
- ModelMapping map[string]map[string]string `json:"model_mapping"`
- BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
- RestrictModels bool `json:"restrict_models"`
+ Name string `json:"name" binding:"required,max=100"`
+ Description string `json:"description"`
+ GroupIDs []int64 `json:"group_ids"`
+ ModelPricing []channelModelPricingRequest `json:"model_pricing"`
+ ModelMapping map[string]map[string]string `json:"model_mapping"`
+ BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
+ RestrictModels bool `json:"restrict_models"`
+ Features string `json:"features"`
+ FeaturesConfig map[string]any `json:"features_config"`
+ ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
+ AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
}
type updateChannelRequest struct {
- Name string `json:"name" binding:"omitempty,max=100"`
- Description *string `json:"description"`
- Status string `json:"status" binding:"omitempty,oneof=active disabled"`
- GroupIDs *[]int64 `json:"group_ids"`
- ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
- ModelMapping map[string]map[string]string `json:"model_mapping"`
- BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
- RestrictModels *bool `json:"restrict_models"`
+ Name string `json:"name" binding:"omitempty,max=100"`
+ Description *string `json:"description"`
+ Status string `json:"status" binding:"omitempty,oneof=active disabled"`
+ GroupIDs *[]int64 `json:"group_ids"`
+ ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
+ ModelMapping map[string]map[string]string `json:"model_mapping"`
+ BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
+ RestrictModels *bool `json:"restrict_models"`
+ Features *string `json:"features"`
+ FeaturesConfig map[string]any `json:"features_config"`
+ ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"`
+ AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
}
type channelModelPricingRequest struct {
@@ -71,18 +80,29 @@ type pricingIntervalRequest struct {
SortOrder int `json:"sort_order"`
}
+type accountStatsPricingRuleRequest struct {
+ Name string `json:"name"`
+ GroupIDs []int64 `json:"group_ids"`
+ AccountIDs []int64 `json:"account_ids"`
+ Pricing []channelModelPricingRequest `json:"pricing"`
+}
+
type channelResponse struct {
- ID int64 `json:"id"`
- Name string `json:"name"`
- Description string `json:"description"`
- Status string `json:"status"`
- BillingModelSource string `json:"billing_model_source"`
- RestrictModels bool `json:"restrict_models"`
- GroupIDs []int64 `json:"group_ids"`
- ModelPricing []channelModelPricingResponse `json:"model_pricing"`
- ModelMapping map[string]map[string]string `json:"model_mapping"`
- CreatedAt string `json:"created_at"`
- UpdatedAt string `json:"updated_at"`
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Status string `json:"status"`
+ BillingModelSource string `json:"billing_model_source"`
+ RestrictModels bool `json:"restrict_models"`
+ Features string `json:"features"`
+ FeaturesConfig map[string]any `json:"features_config"`
+ GroupIDs []int64 `json:"group_ids"`
+ ModelPricing []channelModelPricingResponse `json:"model_pricing"`
+ ModelMapping map[string]map[string]string `json:"model_mapping"`
+ ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
+ AccountStatsPricingRules []accountStatsPricingRuleResponse `json:"account_stats_pricing_rules"`
+ CreatedAt string `json:"created_at"`
+ UpdatedAt string `json:"updated_at"`
}
type channelModelPricingResponse struct {
@@ -112,6 +132,14 @@ type pricingIntervalResponse struct {
SortOrder int `json:"sort_order"`
}
+type accountStatsPricingRuleResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ GroupIDs []int64 `json:"group_ids"`
+ AccountIDs []int64 `json:"account_ids"`
+ Pricing []channelModelPricingResponse `json:"pricing"`
+}
+
func channelToResponse(ch *service.Channel) *channelResponse {
if ch == nil {
return nil
@@ -122,6 +150,8 @@ func channelToResponse(ch *service.Channel) *channelResponse {
Description: ch.Description,
Status: ch.Status,
RestrictModels: ch.RestrictModels,
+ Features: ch.Features,
+ FeaturesConfig: ch.FeaturesConfig,
GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
@@ -142,6 +172,29 @@ func channelToResponse(ch *service.Channel) *channelResponse {
for _, p := range ch.ModelPricing {
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
}
+
+ resp.ApplyPricingToAccountStats = ch.ApplyPricingToAccountStats
+ resp.AccountStatsPricingRules = make([]accountStatsPricingRuleResponse, 0, len(ch.AccountStatsPricingRules))
+ for _, rule := range ch.AccountStatsPricingRules {
+ ruleResp := accountStatsPricingRuleResponse{
+ ID: rule.ID,
+ Name: rule.Name,
+ GroupIDs: rule.GroupIDs,
+ AccountIDs: rule.AccountIDs,
+ Pricing: make([]channelModelPricingResponse, 0, len(rule.Pricing)),
+ }
+ if ruleResp.GroupIDs == nil {
+ ruleResp.GroupIDs = []int64{}
+ }
+ if ruleResp.AccountIDs == nil {
+ ruleResp.AccountIDs = []int64{}
+ }
+ for i := range rule.Pricing {
+ ruleResp.Pricing = append(ruleResp.Pricing, pricingToResponse(&rule.Pricing[i]))
+ }
+ resp.AccountStatsPricingRules = append(resp.AccountStatsPricingRules, ruleResp)
+ }
+
return resp
}
@@ -200,9 +253,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
billingMode = service.BillingModeToken
}
platform := r.Platform
- if platform == "" {
- platform = service.PlatformAnthropic
- }
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
for _, iv := range r.Intervals {
intervals = append(intervals, service.PricingInterval{
@@ -233,6 +283,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
return result
}
+func accountStatsPricingRuleRequestToService(r accountStatsPricingRuleRequest) service.AccountStatsPricingRule {
+ return service.AccountStatsPricingRule{
+ Name: r.Name,
+ GroupIDs: r.GroupIDs,
+ AccountIDs: r.AccountIDs,
+ Pricing: pricingRequestToService(r.Pricing),
+ }
+}
+
// --- Handlers ---
// List handles listing channels with pagination
@@ -291,15 +350,42 @@ func (h *ChannelHandler) Create(c *gin.Context) {
}
pricing := pricingRequestToService(req.ModelPricing)
+ // Main model_pricing requires a platform; default to anthropic for backward compatibility.
+ for i := range pricing {
+ if pricing[i].Platform == "" {
+ pricing[i].Platform = service.PlatformAnthropic
+ }
+ }
+
+ var statsRules []service.AccountStatsPricingRule
+ for i, r := range req.AccountStatsPricingRules {
+ if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
+ response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
+ fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
+ return
+ }
+ if len(r.Pricing) == 0 {
+ response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
+ fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
+ return
+ }
+ rule := accountStatsPricingRuleRequestToService(r)
+ rule.SortOrder = i
+ statsRules = append(statsRules, rule)
+ }
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
- Name: req.Name,
- Description: req.Description,
- GroupIDs: req.GroupIDs,
- ModelPricing: pricing,
- ModelMapping: req.ModelMapping,
- BillingModelSource: req.BillingModelSource,
- RestrictModels: req.RestrictModels,
+ Name: req.Name,
+ Description: req.Description,
+ GroupIDs: req.GroupIDs,
+ ModelPricing: pricing,
+ ModelMapping: req.ModelMapping,
+ BillingModelSource: req.BillingModelSource,
+ RestrictModels: req.RestrictModels,
+ Features: req.Features,
+ FeaturesConfig: req.FeaturesConfig,
+ ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
+ AccountStatsPricingRules: statsRules,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -325,18 +411,45 @@ func (h *ChannelHandler) Update(c *gin.Context) {
}
input := &service.UpdateChannelInput{
- Name: req.Name,
- Description: req.Description,
- Status: req.Status,
- GroupIDs: req.GroupIDs,
- ModelMapping: req.ModelMapping,
- BillingModelSource: req.BillingModelSource,
- RestrictModels: req.RestrictModels,
+ Name: req.Name,
+ Description: req.Description,
+ Status: req.Status,
+ GroupIDs: req.GroupIDs,
+ ModelMapping: req.ModelMapping,
+ BillingModelSource: req.BillingModelSource,
+ RestrictModels: req.RestrictModels,
+ Features: req.Features,
+ FeaturesConfig: req.FeaturesConfig,
+ ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
}
if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing)
+ for i := range pricing {
+ if pricing[i].Platform == "" {
+ pricing[i].Platform = service.PlatformAnthropic
+ }
+ }
input.ModelPricing = &pricing
}
+ if req.AccountStatsPricingRules != nil {
+ statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules))
+ for i, r := range *req.AccountStatsPricingRules {
+ if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
+ response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
+ fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
+ return
+ }
+ if len(r.Pricing) == 0 {
+ response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
+ fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
+ return
+ }
+ rule := accountStatsPricingRuleRequestToService(r)
+ rule.SortOrder = i
+ statsRules = append(statsRules, rule)
+ }
+ input.AccountStatsPricingRules = &statsRules
+ }
channel, err := h.channelService.Update(c.Request.Context(), id, input)
if err != nil {
diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go
index 2f4b4440..f218cce4 100644
--- a/backend/internal/handler/admin/channel_handler_test.go
+++ b/backend/internal/handler/admin/channel_handler_test.go
@@ -273,13 +273,13 @@ func TestPricingRequestToService_Defaults(t *testing.T) {
wantValue: string(service.BillingModeToken),
},
{
- name: "empty platform defaults to anthropic",
+ name: "empty platform stays empty",
req: channelModelPricingRequest{
Models: []string{"m1"},
Platform: "",
},
wantField: "Platform",
- wantValue: "anthropic",
+ wantValue: "",
},
}
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index ba751131..29c97b4b 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -5,11 +5,10 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
- "log"
+ "log/slog"
"net/http"
"regexp"
"strings"
- "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -175,6 +174,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
+ WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
+ BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
+ BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
+ BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
+ AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
+ AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
PaymentEnabled: paymentCfg.Enabled,
PaymentMinAmount: paymentCfg.MinAmount,
PaymentMaxAmount: paymentCfg.MaxAmount,
@@ -304,6 +309,13 @@ type UpdateSettingsRequest struct {
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"`
+ // Balance low notification
+ BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
+ BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
+ BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"`
+ AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"`
+ AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"`
+
// Payment configuration (integrated into settings, full replace)
PaymentEnabled *bool `json:"payment_enabled"`
PaymentMinAmount *float64 `json:"payment_min_amount"`
@@ -881,6 +893,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableCCHSigning
}(),
+ BalanceLowNotifyEnabled: func() bool {
+ if req.BalanceLowNotifyEnabled != nil {
+ return *req.BalanceLowNotifyEnabled
+ }
+ return previousSettings.BalanceLowNotifyEnabled
+ }(),
+ BalanceLowNotifyThreshold: func() float64 {
+ if req.BalanceLowNotifyThreshold != nil {
+ return *req.BalanceLowNotifyThreshold
+ }
+ return previousSettings.BalanceLowNotifyThreshold
+ }(),
+ BalanceLowNotifyRechargeURL: func() string {
+ if req.BalanceLowNotifyRechargeURL != nil {
+ return *req.BalanceLowNotifyRechargeURL
+ }
+ return previousSettings.BalanceLowNotifyRechargeURL
+ }(),
+ AccountQuotaNotifyEnabled: func() bool {
+ if req.AccountQuotaNotifyEnabled != nil {
+ return *req.AccountQuotaNotifyEnabled
+ }
+ return previousSettings.AccountQuotaNotifyEnabled
+ }(),
+ AccountQuotaNotifyEmails: func() []service.NotifyEmailEntry {
+ if req.AccountQuotaNotifyEmails != nil {
+ return dto.NotifyEmailEntriesToService(*req.AccountQuotaNotifyEmails)
+ }
+ return previousSettings.AccountQuotaNotifyEmails
+ }(),
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
@@ -1027,6 +1069,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning,
+ BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
+ BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
+ BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
+ AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
+ AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
PaymentEnabled: updatedPaymentCfg.Enabled,
PaymentMinAmount: updatedPaymentCfg.MinAmount,
PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
@@ -1073,11 +1120,11 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys
subject, _ := middleware.GetAuthSubjectFromContext(c)
role, _ := middleware.GetUserRoleFromContext(c)
- log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v",
- time.Now().UTC().Format(time.RFC3339),
- subject.UserID,
- role,
- changed,
+ slog.Info("settings updated",
+ "audit", true,
+ "user_id", subject.UserID,
+ "role", role,
+ "changed", changed,
)
}
@@ -1092,6 +1139,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
changed = append(changed, "registration_email_suffix_whitelist")
}
+ if before.PromoCodeEnabled != after.PromoCodeEnabled {
+ changed = append(changed, "promo_code_enabled")
+ }
+ if before.InvitationCodeEnabled != after.InvitationCodeEnabled {
+ changed = append(changed, "invitation_code_enabled")
+ }
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
@@ -1302,6 +1355,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.CustomMenuItems != after.CustomMenuItems {
changed = append(changed, "custom_menu_items")
}
+ if before.CustomEndpoints != after.CustomEndpoints {
+ changed = append(changed, "custom_endpoints")
+ }
if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
changed = append(changed, "enable_fingerprint_unification")
}
@@ -1311,6 +1367,22 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EnableCCHSigning != after.EnableCCHSigning {
changed = append(changed, "enable_cch_signing")
}
+ // Balance & quota notification
+ if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
+ changed = append(changed, "balance_low_notify_enabled")
+ }
+ if before.BalanceLowNotifyThreshold != after.BalanceLowNotifyThreshold {
+ changed = append(changed, "balance_low_notify_threshold")
+ }
+ if before.BalanceLowNotifyRechargeURL != after.BalanceLowNotifyRechargeURL {
+ changed = append(changed, "balance_low_notify_recharge_url")
+ }
+ if before.AccountQuotaNotifyEnabled != after.AccountQuotaNotifyEnabled {
+ changed = append(changed, "account_quota_notify_enabled")
+ }
+ if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
+ changed = append(changed, "account_quota_notify_emails")
+ }
return changed
}
@@ -1367,6 +1439,18 @@ func equalIntSlice(a, b []int) bool {
return true
}
+func equalNotifyEmailEntries(a, b []service.NotifyEmailEntry) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i := range a {
+ if a[i].Email != b[i].Email || a[i].Verified != b[i].Verified || a[i].Disabled != b[i].Disabled {
+ return false
+ }
+ }
+ return true
+}
+
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host"`
@@ -1847,3 +1931,80 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
})
}
+
+// GetWebSearchEmulationConfig 获取 Web Search 模拟配置
+// GET /api/v1/admin/settings/web-search-emulation
+func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) {
+ cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), cfg))
+}
+
+// UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置
+// PUT /api/v1/admin/settings/web-search-emulation
+func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) {
+ var cfg service.WebSearchEmulationConfig
+ if err := c.ShouldBindJSON(&cfg); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Re-read (with sanitized api keys) to return current state
+ updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), updated))
+}
+
+// ResetWebSearchUsage 重置指定 provider 的配额用量
+// POST /api/v1/admin/settings/web-search-emulation/reset-usage
+func (h *SettingHandler) ResetWebSearchUsage(c *gin.Context) {
+ var req struct {
+ ProviderType string `json:"provider_type"`
+ }
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ if req.ProviderType == "" {
+ response.BadRequest(c, "provider_type is required")
+ return
+ }
+ if err := service.ResetWebSearchUsage(c.Request.Context(), req.ProviderType); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, nil)
+}
+
+// TestWebSearchEmulation 测试 Web Search 搜索
+// POST /api/v1/admin/settings/web-search-emulation/test
+func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) {
+ var req struct {
+ Query string `json:"query"`
+ }
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ if strings.TrimSpace(req.Query) == "" {
+ req.Query = "搜索今年世界大事件"
+ }
+
+ result, err := service.TestWebSearch(c.Request.Context(), req.Query)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, result)
+}
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index 478600eb..d2ccb8d6 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -13,16 +13,21 @@ func UserFromServiceShallow(u *service.User) *User {
return nil
}
return &User{
- ID: u.ID,
- Email: u.Email,
- Username: u.Username,
- Role: u.Role,
- Balance: u.Balance,
- Concurrency: u.Concurrency,
- Status: u.Status,
- AllowedGroups: u.AllowedGroups,
- CreatedAt: u.CreatedAt,
- UpdatedAt: u.UpdatedAt,
+ ID: u.ID,
+ Email: u.Email,
+ Username: u.Username,
+ Role: u.Role,
+ Balance: u.Balance,
+ Concurrency: u.Concurrency,
+ Status: u.Status,
+ AllowedGroups: u.AllowedGroups,
+ CreatedAt: u.CreatedAt,
+ UpdatedAt: u.UpdatedAt,
+ BalanceNotifyEnabled: u.BalanceNotifyEnabled,
+ BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
+ BalanceNotifyThreshold: u.BalanceNotifyThreshold,
+ BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
+ TotalRecharged: u.TotalRecharged,
}
}
@@ -322,6 +327,26 @@ func AccountFromServiceShallow(a *service.Account) *Account {
out.QuotaWeeklyResetAt = &v
}
}
+
+ // 配额通知配置
+ if enabled := a.GetQuotaNotifyDailyEnabled(); enabled {
+ out.QuotaNotifyDailyEnabled = &enabled
+ }
+ if threshold := a.GetQuotaNotifyDailyThreshold(); threshold > 0 {
+ out.QuotaNotifyDailyThreshold = &threshold
+ }
+ if enabled := a.GetQuotaNotifyWeeklyEnabled(); enabled {
+ out.QuotaNotifyWeeklyEnabled = &enabled
+ }
+ if threshold := a.GetQuotaNotifyWeeklyThreshold(); threshold > 0 {
+ out.QuotaNotifyWeeklyThreshold = &threshold
+ }
+ if enabled := a.GetQuotaNotifyTotalEnabled(); enabled {
+ out.QuotaNotifyTotalEnabled = &enabled
+ }
+ if threshold := a.GetQuotaNotifyTotalThreshold(); threshold > 0 {
+ out.QuotaNotifyTotalThreshold = &threshold
+ }
}
return out
@@ -603,6 +628,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
ModelMappingChain: l.ModelMappingChain,
BillingTier: l.BillingTier,
AccountRateMultiplier: l.AccountRateMultiplier,
+ AccountStatsCost: l.AccountStatsCost,
IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account),
}
diff --git a/backend/internal/handler/dto/notify_email_entry.go b/backend/internal/handler/dto/notify_email_entry.go
new file mode 100644
index 00000000..78641005
--- /dev/null
+++ b/backend/internal/handler/dto/notify_email_entry.go
@@ -0,0 +1,43 @@
+package dto
+
+import "github.com/Wei-Shaw/sub2api/internal/service"
+
+// NotifyEmailEntry represents a notification email with enable/disable and verification state.
+// All emails are user-managed; maximum 3 entries per user.
+type NotifyEmailEntry struct {
+ Email string `json:"email"`
+ Disabled bool `json:"disabled"`
+ Verified bool `json:"verified"`
+}
+
+// NotifyEmailEntriesFromService converts service entries to DTO entries.
+func NotifyEmailEntriesFromService(entries []service.NotifyEmailEntry) []NotifyEmailEntry {
+ if entries == nil {
+ return nil
+ }
+ result := make([]NotifyEmailEntry, len(entries))
+ for i, e := range entries {
+ result[i] = NotifyEmailEntry{
+ Email: e.Email,
+ Disabled: e.Disabled,
+ Verified: e.Verified,
+ }
+ }
+ return result
+}
+
+// NotifyEmailEntriesToService converts DTO entries to service entries.
+func NotifyEmailEntriesToService(entries []NotifyEmailEntry) []service.NotifyEmailEntry {
+ if entries == nil {
+ return nil
+ }
+ result := make([]service.NotifyEmailEntry, len(entries))
+ for i, e := range entries {
+ result[i] = service.NotifyEmailEntry{
+ Email: e.Email,
+ Disabled: e.Disabled,
+ Verified: e.Verified,
+ }
+ }
+ return result
+}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index cbbe9216..ef285a44 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -124,6 +124,9 @@ type SystemSettings struct {
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
EnableCCHSigning bool `json:"enable_cch_signing"`
+ // Web Search Emulation
+ WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
+
// Payment configuration
PaymentEnabled bool `json:"payment_enabled"`
PaymentMinAmount float64 `json:"payment_min_amount"`
@@ -145,6 +148,13 @@ type SystemSettings struct {
PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"`
PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
+
+ // Balance low notification
+ BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
+ BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
+ BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
+ AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
+ AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
}
type DefaultSubscriptionSetting struct {
@@ -183,6 +193,10 @@ type PublicSettings struct {
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version"`
+ BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
+ AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
+ BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
+ BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index e026ca65..8c1e166f 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -18,6 +18,13 @@ type User struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
+ // 余额不足通知
+ BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
+ BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
+ BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
+ TotalRecharged float64 `json:"total_recharged"`
+
APIKeys []APIKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
@@ -218,6 +225,14 @@ type Account struct {
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
+ // 配额通知配置
+ QuotaNotifyDailyEnabled *bool `json:"quota_notify_daily_enabled,omitempty"`
+ QuotaNotifyDailyThreshold *float64 `json:"quota_notify_daily_threshold,omitempty"`
+ QuotaNotifyWeeklyEnabled *bool `json:"quota_notify_weekly_enabled,omitempty"`
+ QuotaNotifyWeeklyThreshold *float64 `json:"quota_notify_weekly_threshold,omitempty"`
+ QuotaNotifyTotalEnabled *bool `json:"quota_notify_total_enabled,omitempty"`
+ QuotaNotifyTotalThreshold *float64 `json:"quota_notify_total_threshold,omitempty"`
+
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -412,6 +427,8 @@ type AdminUsageLog struct {
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
+ // AccountStatsCost 自定义定价规则计算的账号统计费用(nil 表示使用默认公式)
+ AccountStatsCost *float64 `json:"account_stats_cost,omitempty"`
// IPAddress 用户请求 IP(仅管理员可见)
IPAddress *string `json:"ip_address,omitempty"`
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index 59619d50..f5eff8c9 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -248,6 +248,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
+ // 设置请求所属分组 ID(用于渠道级功能判断,如 WebSearch 模拟)
+ parsedReq.GroupID = apiKey.GroupID
+
// 计算粘性会话hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
@@ -470,6 +473,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
+ ParsedRequest: parsedReq,
APIKey: apiKey,
User: apiKey.User,
Account: account,
@@ -518,7 +522,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0))
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -672,6 +676,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 转发请求 - 根据账号平台分流
+ c.Set("parsed_request", parsedReq)
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
@@ -810,6 +815,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
+ ParsedRequest: parsedReq,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,
diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
index acea3780..1fdc46ba 100644
--- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
+++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
@@ -168,6 +168,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // tlsFPProfileService
nil, // channelService
nil, // resolver
+ nil, // balanceNotifyService
)
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
index 0425fc49..5fde86fa 100644
--- a/backend/internal/handler/payment_handler.go
+++ b/backend/internal/handler/payment_handler.go
@@ -335,6 +335,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) {
response.Success(c, gin.H{"message": "refund requested"})
}
+// GetRefundEligibleProviders returns provider instance IDs that allow user refund.
+func (h *PaymentHandler) GetRefundEligibleProviders(c *gin.Context) {
+ ids, err := h.configService.GetUserRefundEligibleInstanceIDs(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"provider_instance_ids": ids})
+}
+
// VerifyOrderRequest is the request body for verifying a payment order.
type VerifyOrderRequest struct {
OutTradeNo string `json:"out_trade_no" binding:"required"`
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 54a92a8c..1717b7a1 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -61,5 +61,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
Version: h.version,
+ BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
+ AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
+ BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
+ BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
})
}
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 35862f1c..2535ea5e 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -11,13 +11,17 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
- userService *service.UserService
+ userService *service.UserService
+ emailService *service.EmailService
+ emailCache service.EmailCache
}
// NewUserHandler creates a new UserHandler
-func NewUserHandler(userService *service.UserService) *UserHandler {
+func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler {
return &UserHandler{
- userService: userService,
+ userService: userService,
+ emailService: emailService,
+ emailCache: emailCache,
}
}
@@ -29,7 +33,9 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
- Username *string `json:"username"`
+ Username *string `json:"username"`
+ BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
// GetProfile handles getting user profile
@@ -94,7 +100,9 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
}
svcReq := service.UpdateProfileRequest{
- Username: req.Username,
+ Username: req.Username,
+ BalanceNotifyEnabled: req.BalanceNotifyEnabled,
+ BalanceNotifyThreshold: req.BalanceNotifyThreshold,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
@@ -104,3 +112,141 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
response.Success(c, dto.UserFromService(updatedUser))
}
+
+// SendNotifyEmailCodeRequest represents the request to send notify email verification code
+type SendNotifyEmailCodeRequest struct {
+ Email string `json:"email" binding:"required,email"`
+}
+
+// SendNotifyEmailCode sends verification code to extra notification email
+// POST /api/v1/user/notify-email/send-code
+func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req SendNotifyEmailCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "Verification code sent successfully"})
+}
+
+// VerifyNotifyEmailRequest represents the request to verify and add notify email
+type VerifyNotifyEmailRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ Code string `json:"code" binding:"required,len=6"`
+}
+
+// VerifyNotifyEmail verifies code and adds email to notification list
+// POST /api/v1/user/notify-email/verify
+func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req VerifyNotifyEmailRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ err := h.userService.VerifyAndAddNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Code, h.emailCache)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Return updated user
+ updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, dto.UserFromService(updatedUser))
+}
+
+// RemoveNotifyEmailRequest represents the request to remove a notify email
+type RemoveNotifyEmailRequest struct {
+ Email string `json:"email" binding:"required,email"`
+}
+
+// RemoveNotifyEmail removes email from notification list
+// DELETE /api/v1/user/notify-email
+func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req RemoveNotifyEmailRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ err := h.userService.RemoveNotifyEmail(c.Request.Context(), subject.UserID, req.Email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Return updated user
+ updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, dto.UserFromService(updatedUser))
+}
+
+// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
+type ToggleNotifyEmailRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ Disabled bool `json:"disabled"`
+}
+
+// ToggleNotifyEmail toggles the disabled state of a notification email
+// PUT /api/v1/user/notify-email/toggle
+func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req ToggleNotifyEmailRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ err := h.userService.ToggleNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Disabled)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, dto.UserFromService(updatedUser))
+}
diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go
index 55cb2043..f0353173 100644
--- a/backend/internal/payment/load_balancer.go
+++ b/backend/internal/payment/load_balancer.go
@@ -117,7 +117,13 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
var matched []*dbent.PaymentProviderInstance
for _, inst := range instances {
- if InstanceSupportsType(inst.SupportedTypes, paymentType) {
+ // Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay),
+ // not "stripe" itself. The checkout page aggregates all sub-types under "stripe".
+ if paymentType == TypeStripe {
+ if inst.ProviderKey == TypeStripe {
+ matched = append(matched, inst)
+ }
+ } else if InstanceSupportsType(inst.SupportedTypes, paymentType) {
matched = append(matched, inst)
}
}
diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go
index 568b56a3..04b3c25b 100644
--- a/backend/internal/payment/load_balancer_test.go
+++ b/backend/internal/payment/load_balancer_test.go
@@ -242,7 +242,7 @@ func TestFilterByLimits(t *testing.T) {
wantIDs: nil,
},
{
- name: "empty candidates returns empty",
+ name: "empty candidates returns empty",
candidates: nil,
paymentType: "alipay",
orderAmount: 10,
diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go
index 1b9d66ba..7b0ce0d8 100644
--- a/backend/internal/payment/provider/alipay_test.go
+++ b/backend/internal/payment/provider/alipay_test.go
@@ -98,9 +98,9 @@ func TestNewAlipay(t *testing.T) {
errSubstr: "privateKey",
},
{
- name: "nil config map returns error for appId",
- config: map[string]string{},
- wantErr: true,
+ name: "nil config map returns error for appId",
+ config: map[string]string{},
+ wantErr: true,
errSubstr: "appId",
},
}
diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go
index 58982878..4a68f3a9 100644
--- a/backend/internal/pkg/antigravity/stream_transformer.go
+++ b/backend/internal/pkg/antigravity/stream_transformer.go
@@ -18,6 +18,9 @@ const (
BlockTypeFunction
)
+// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
+type UsageMapHook func(usageMap map[string]any)
+
// StreamingProcessor 流式响应处理器
type StreamingProcessor struct {
blockType BlockType
@@ -30,6 +33,7 @@ type StreamingProcessor struct {
originalModel string
webSearchQueries []string
groundingChunks []GeminiGroundingChunk
+ usageMapHook UsageMapHook
// 累计 usage
inputTokens int
@@ -46,6 +50,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
}
}
+// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
+func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
+ p.usageMapHook = fn
+}
+
+func usageToMap(u ClaudeUsage) map[string]any {
+ m := map[string]any{
+ "input_tokens": u.InputTokens,
+ "output_tokens": u.OutputTokens,
+ }
+ if u.CacheCreationInputTokens > 0 {
+ m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
+ }
+ if u.CacheReadInputTokens > 0 {
+ m["cache_read_input_tokens"] = u.CacheReadInputTokens
+ }
+ if u.ImageOutputTokens > 0 {
+ m["image_output_tokens"] = u.ImageOutputTokens
+ }
+ return m
+}
+
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line)
@@ -172,6 +198,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
responseID = "msg_" + generateRandomID()
}
+ var usageValue any = usage
+ if p.usageMapHook != nil {
+ usageMap := usageToMap(usage)
+ p.usageMapHook(usageMap)
+ usageValue = usageMap
+ }
+
message := map[string]any{
"id": responseID,
"type": "message",
@@ -180,7 +213,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
"model": p.originalModel,
"stop_reason": nil,
"stop_sequence": nil,
- "usage": usage,
+ "usage": usageValue,
}
event := map[string]any{
@@ -492,13 +525,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
ImageOutputTokens: p.imageOutputTokens,
}
+ var usageValue any = usage
+ if p.usageMapHook != nil {
+ usageMap := usageToMap(usage)
+ p.usageMapHook(usageMap)
+ usageValue = usageMap
+ }
+
deltaEvent := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": stopReason,
"stop_sequence": nil,
},
- "usage": usage,
+ "usage": usageValue,
}
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
index dc157a6d..c2725406 100644
--- a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
+++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
@@ -27,13 +27,14 @@ func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest,
}
out := &ResponsesRequest{
- Model: req.Model,
- Input: inputJSON,
- Temperature: req.Temperature,
- TopP: req.TopP,
- Stream: true, // upstream always streams
- Include: []string{"reasoning.encrypted_content"},
- ServiceTier: req.ServiceTier,
+ Model: req.Model,
+ Instructions: req.Instructions,
+ Input: inputJSON,
+ Temperature: req.Temperature,
+ TopP: req.TopP,
+ Stream: true, // upstream always streams
+ Include: []string{"reasoning.encrypted_content"},
+ ServiceTier: req.ServiceTier,
}
storeFalse := false
diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go
index b383f867..e0d1a53e 100644
--- a/backend/internal/pkg/apicompat/types.go
+++ b/backend/internal/pkg/apicompat/types.go
@@ -152,6 +152,7 @@ type AnthropicDelta struct {
// ResponsesRequest is the request body for POST /v1/responses.
type ResponsesRequest struct {
Model string `json:"model"`
+ Instructions string `json:"instructions,omitempty"`
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
@@ -337,6 +338,7 @@ type ResponsesStreamEvent struct {
type ChatCompletionsRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
+ Instructions string `json:"instructions,omitempty"` // OpenAI Responses API compat
MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
diff --git a/backend/internal/pkg/logger/logger_test.go b/backend/internal/pkg/logger/logger_test.go
index 74aae061..06a277a4 100644
--- a/backend/internal/pkg/logger/logger_test.go
+++ b/backend/internal/pkg/logger/logger_test.go
@@ -10,7 +10,13 @@ import (
)
func TestInit_DualOutput(t *testing.T) {
- tmpDir := t.TempDir()
+ // Use os.MkdirTemp instead of t.TempDir to avoid cleanup failures
+ // when lumberjack holds file handles on Windows.
+ tmpDir, err := os.MkdirTemp("", "logger-test-*")
+ if err != nil {
+ t.Fatalf("create temp dir: %v", err)
+ }
+ t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
logPath := filepath.Join(tmpDir, "logs", "sub2api.log")
origStdout := os.Stdout
@@ -57,7 +63,9 @@ func TestInit_DualOutput(t *testing.T) {
L().Info("dual-output-info")
L().Warn("dual-output-warn")
- Sync()
+
+ // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
+ // The log data is already in the pipe buffer; closing writers is sufficient.
_ = stdoutW.Close()
_ = stderrW.Close()
@@ -166,7 +174,9 @@ func TestInit_CallerShouldPointToCallsite(t *testing.T) {
}
L().Info("caller-check")
- Sync()
+ // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
+ os.Stdout = origStdout
+ os.Stderr = origStderr
_ = stdoutW.Close()
logBytes, _ := io.ReadAll(stdoutR)
diff --git a/backend/internal/pkg/logger/stdlog_bridge_test.go b/backend/internal/pkg/logger/stdlog_bridge_test.go
index 4482a2ec..30d25b33 100644
--- a/backend/internal/pkg/logger/stdlog_bridge_test.go
+++ b/backend/internal/pkg/logger/stdlog_bridge_test.go
@@ -77,7 +77,7 @@ func TestStdLogBridgeRoutesLevels(t *testing.T) {
log.Printf("service started")
log.Printf("Warning: queue full")
log.Printf("Forward request failed: timeout")
- Sync()
+ // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
_ = stdoutW.Close()
_ = stderrW.Close()
@@ -139,7 +139,7 @@ func TestLegacyPrintfRoutesLevels(t *testing.T) {
LegacyPrintf("service.test", "request started")
LegacyPrintf("service.test", "Warning: queue full")
LegacyPrintf("service.test", "forward failed: timeout")
- Sync()
+ // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
_ = stdoutW.Close()
_ = stderrW.Close()
diff --git a/backend/internal/pkg/websearch/brave.go b/backend/internal/pkg/websearch/brave.go
new file mode 100644
index 00000000..707e7029
--- /dev/null
+++ b/backend/internal/pkg/websearch/brave.go
@@ -0,0 +1,106 @@
+package websearch
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+)
+
+const (
+ braveSearchEndpoint = "https://api.search.brave.com/res/v1/web/search"
+ braveMaxCount = 20
+ braveProviderName = "brave"
+)
+
+// braveSearchURL is pre-parsed at init time; url.Parse cannot fail on a constant literal.
+var braveSearchURL, _ = url.Parse(braveSearchEndpoint) //nolint:errcheck
+
+// BraveProvider implements web search via the Brave Search API.
+type BraveProvider struct {
+ apiKey string
+ httpClient *http.Client
+}
+
+// NewBraveProvider creates a Brave Search provider.
+// The caller is responsible for configuring the http.Client with proxy/timeouts.
+func NewBraveProvider(apiKey string, httpClient *http.Client) *BraveProvider {
+ if httpClient == nil {
+ httpClient = http.DefaultClient
+ }
+ return &BraveProvider{apiKey: apiKey, httpClient: httpClient}
+}
+
+func (b *BraveProvider) Name() string { return braveProviderName }
+
+func (b *BraveProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
+ count := req.MaxResults
+ if count <= 0 {
+ count = defaultMaxResults
+ }
+ if count > braveMaxCount {
+ count = braveMaxCount
+ }
+
+ u := *braveSearchURL // copy the pre-parsed URL
+ q := u.Query()
+ q.Set("q", req.Query)
+ q.Set("count", strconv.Itoa(count))
+ u.RawQuery = q.Encode()
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
+ if err != nil {
+ return nil, fmt.Errorf("brave: build request: %w", err)
+ }
+ httpReq.Header.Set("X-Subscription-Token", b.apiKey)
+ httpReq.Header.Set("Accept", "application/json")
+
+ resp, err := b.httpClient.Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("brave: request failed: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
+ if err != nil {
+ return nil, fmt.Errorf("brave: read body: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("brave: status %d: %s", resp.StatusCode, truncateBody(body))
+ }
+
+ var raw braveResponse
+ if err := json.Unmarshal(body, &raw); err != nil {
+ return nil, fmt.Errorf("brave: decode response: %w", err)
+ }
+
+ results := make([]SearchResult, 0, len(raw.Web.Results))
+ for _, r := range raw.Web.Results {
+ results = append(results, SearchResult{
+ URL: r.URL,
+ Title: r.Title,
+ Snippet: r.Description,
+ PageAge: r.Age,
+ })
+ }
+
+ return &SearchResponse{Results: results, Query: req.Query}, nil
+}
+
+// braveResponse is the minimal structure of the Brave Search API response.
+type braveResponse struct {
+ Web struct {
+ Results []braveResult `json:"results"`
+ } `json:"web"`
+}
+
+type braveResult struct {
+ URL string `json:"url"`
+ Title string `json:"title"`
+ Description string `json:"description"`
+ Age string `json:"age"`
+}
diff --git a/backend/internal/pkg/websearch/brave_test.go b/backend/internal/pkg/websearch/brave_test.go
new file mode 100644
index 00000000..4dc5b219
--- /dev/null
+++ b/backend/internal/pkg/websearch/brave_test.go
@@ -0,0 +1,119 @@
+package websearch
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestBraveProvider_Name(t *testing.T) {
+ p := NewBraveProvider("key", nil)
+ require.Equal(t, "brave", p.Name())
+}
+
+func TestBraveProvider_Search_Success(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "test-key", r.Header.Get("X-Subscription-Token"))
+ require.Equal(t, "application/json", r.Header.Get("Accept"))
+ require.Equal(t, "golang", r.URL.Query().Get("q"))
+ require.Equal(t, "3", r.URL.Query().Get("count"))
+
+ resp := braveResponse{}
+ resp.Web.Results = []braveResult{
+ {URL: "https://go.dev", Title: "Go", Description: "Go lang", Age: "1 day"},
+ {URL: "https://pkg.go.dev", Title: "Pkg", Description: "Packages"},
+ {URL: "https://tour.go.dev", Title: "Tour", Description: "A Tour of Go", Age: "3 days"},
+ }
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("test-key", srv.Client())
+ // Override the endpoint for testing
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ resp, err := p.Search(context.Background(), SearchRequest{Query: "golang", MaxResults: 3})
+ require.NoError(t, err)
+ require.Len(t, resp.Results, 3)
+ require.Equal(t, "https://go.dev", resp.Results[0].URL)
+ require.Equal(t, "Go lang", resp.Results[0].Snippet)
+ require.Equal(t, "1 day", resp.Results[0].PageAge)
+}
+
+func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) {
+ var receivedCount string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ receivedCount = r.URL.Query().Get("count")
+ resp := braveResponse{}
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ _, _ = p.Search(context.Background(), SearchRequest{Query: "test", MaxResults: 0})
+ require.Equal(t, "5", receivedCount)
+}
+
+func TestBraveProvider_Search_HTTPError(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(429)
+ _, _ = w.Write([]byte("rate limited"))
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ _, err := p.Search(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "brave: status 429")
+}
+
+func TestBraveProvider_Search_InvalidJSON(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ _, _ = w.Write([]byte("not json"))
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ _, err := p.Search(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "brave: decode response")
+}
+
+func TestBraveProvider_Search_EmptyResults(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ resp := braveResponse{}
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ resp, err := p.Search(context.Background(), SearchRequest{Query: "test"})
+ require.NoError(t, err)
+ require.Empty(t, resp.Results)
+}
diff --git a/backend/internal/pkg/websearch/helpers.go b/backend/internal/pkg/websearch/helpers.go
new file mode 100644
index 00000000..0d08b749
--- /dev/null
+++ b/backend/internal/pkg/websearch/helpers.go
@@ -0,0 +1,14 @@
+package websearch
+
+const (
+ maxResponseSize = 1 << 20 // 1 MB
+ errorBodyTruncLen = 200
+)
+
+// truncateBody returns a truncated string of body for error messages.
+func truncateBody(body []byte) string {
+ if len(body) <= errorBodyTruncLen {
+ return string(body)
+ }
+ return string(body[:errorBodyTruncLen]) + "...(truncated)"
+}
diff --git a/backend/internal/pkg/websearch/helpers_test.go b/backend/internal/pkg/websearch/helpers_test.go
new file mode 100644
index 00000000..e3164329
--- /dev/null
+++ b/backend/internal/pkg/websearch/helpers_test.go
@@ -0,0 +1,25 @@
+package websearch
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestTruncateBody_Short(t *testing.T) {
+ body := []byte("short body")
+ require.Equal(t, "short body", truncateBody(body))
+}
+
+func TestTruncateBody_Long(t *testing.T) {
+ body := []byte(strings.Repeat("x", 500))
+ result := truncateBody(body)
+ require.Len(t, result, errorBodyTruncLen+len("...(truncated)"))
+ require.True(t, strings.HasSuffix(result, "...(truncated)"))
+}
+
+func TestTruncateBody_ExactBoundary(t *testing.T) {
+ body := []byte(strings.Repeat("x", errorBodyTruncLen))
+ require.Equal(t, string(body), truncateBody(body))
+}
diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go
new file mode 100644
index 00000000..307aa1e9
--- /dev/null
+++ b/backend/internal/pkg/websearch/manager.go
@@ -0,0 +1,528 @@
+package websearch
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "log/slog"
+ "math/rand"
+ "net"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
+ "github.com/redis/go-redis/v9"
+)
+
+// ProviderConfig holds the configuration for a single search provider.
+type ProviderConfig struct {
+ Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily
+ APIKey string `json:"api_key"` // secret
+ QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
+ SubscribedAt *int64 `json:"subscribed_at,omitempty"` // subscription start (unix seconds); quota resets monthly from this date
+ ProxyURL string `json:"-"` // resolved proxy URL (not persisted)
+ ProxyID int64 `json:"-"` // resolved proxy ID for unavailability tracking
+ ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds)
+}
+
+// Manager selects providers by quota-weighted load balancing and tracks quota via Redis.
+type Manager struct {
+ configs []ProviderConfig
+ redis *redis.Client
+
+ clientMu sync.Mutex
+ clientCache map[string]*http.Client
+}
+
+// Timeout constants for proxy and search operations.
+const (
+ proxyDialTimeout = 3 * time.Second // proxy TCP connection timeout
+ proxyTLSTimeout = 3 * time.Second // TLS handshake timeout
+ searchDataTimeout = 60 * time.Second // response data transfer timeout
+ searchRequestTimeout = searchDataTimeout + proxyDialTimeout
+
+ quotaKeyPrefix = "websearch:quota:"
+ proxyUnavailableKey = "websearch:proxy_unavailable:%d"
+ proxyUnavailableTTL = 5 * time.Minute
+ quotaTTLBuffer = 24 * time.Hour
+ defaultQuotaTTL = 31*24*time.Hour + quotaTTLBuffer // fallback when no subscription date
+ maxCachedClients = 100
+)
+
+// ErrProxyUnavailable indicates the search failed due to a proxy connectivity issue.
+// Callers may use this to trigger account switching instead of direct fallback.
+var ErrProxyUnavailable = errors.New("websearch: proxy unavailable")
+
+// quotaIncrScript atomically increments the counter and sets TTL on first creation.
+var quotaIncrScript = redis.NewScript(`
+local val = redis.call('INCR', KEYS[1])
+if val == 1 then
+ redis.call('EXPIRE', KEYS[1], ARGV[1])
+else
+ local ttl = redis.call('TTL', KEYS[1])
+ if ttl == -1 then
+ redis.call('EXPIRE', KEYS[1], ARGV[1])
+ end
+end
+return val
+`)
+
+// NewManager creates a Manager with the given provider configs and Redis client.
+// Provider order is preserved as-is; selectByQuotaWeight handles load balancing.
+func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager {
+ copied := make([]ProviderConfig, len(configs))
+ copy(copied, configs)
+ return &Manager{
+ configs: copied,
+ redis: redisClient,
+ clientCache: make(map[string]*http.Client),
+ }
+}
+
+// SearchWithBestProvider selects a provider using quota-weighted load balancing,
+// reserves quota, executes the search, and rolls back quota on failure.
+// If the search fails due to a proxy error, the proxy is marked unavailable for 5 minutes.
+func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
+ if strings.TrimSpace(req.Query) == "" {
+ return nil, "", fmt.Errorf("websearch: empty search query")
+ }
+
+ candidates := m.filterAvailableProviders(ctx, req.ProxyURL)
+ if len(candidates) == 0 {
+ return nil, "", fmt.Errorf("websearch: no available provider (all exhausted, expired, or proxy unavailable)")
+ }
+
+ selected := m.selectByQuotaWeight(ctx, candidates)
+
+ for _, cfg := range selected {
+ allowed, incremented := m.tryReserveQuota(ctx, cfg)
+ if !allowed {
+ continue
+ }
+ resp, err := m.executeSearch(ctx, cfg, req)
+ if err != nil {
+ if incremented {
+ m.rollbackQuota(ctx, cfg)
+ }
+ if isProxyError(err) {
+ m.markProxyUnavailable(ctx, cfg, req.ProxyURL)
+ if req.ProxyURL != "" {
+ // Account-level proxy is shared by all providers — no point
+ // trying others with the same broken proxy; signal account switch.
+ slog.Warn("websearch: account proxy error, aborting failover",
+ "provider", cfg.Type, "error", err)
+ return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error())
+ }
+ // Provider-specific proxy failed — try the next provider which
+ // may use a different (or no) proxy.
+ slog.Warn("websearch: provider proxy error, trying next provider",
+ "provider", cfg.Type, "error", err)
+ continue
+ }
+ slog.Warn("websearch: provider search failed",
+ "provider", cfg.Type, "error", err)
+ continue
+ }
+ return resp, cfg.Type, nil
+ }
+ return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)")
+}
+
+// filterAvailableProviders returns providers that have API keys, are not expired,
+// and whose proxies are not marked unavailable.
+func (m *Manager) filterAvailableProviders(ctx context.Context, accountProxyURL string) []ProviderConfig {
+ var out []ProviderConfig
+ for _, cfg := range m.configs {
+ if !m.isProviderAvailable(cfg) {
+ continue
+ }
+ proxyID := resolveProxyID(cfg, accountProxyURL)
+ if proxyID > 0 && !m.isProxyAvailable(ctx, proxyID) {
+ slog.Debug("websearch: proxy marked unavailable, skipping",
+ "provider", cfg.Type, "proxy_id", proxyID)
+ continue
+ }
+ out = append(out, cfg)
+ }
+ return out
+}
+
+// weighted is a provider candidate with computed quota weight.
+type weighted struct {
+ cfg ProviderConfig
+ weight int64
+}
+
+// selectByQuotaWeight orders candidates by remaining quota weight.
+// Providers with quota_limit=0 (no limit set) get weight 0 and are placed last.
+// Among providers with quota, higher remaining quota = higher priority.
+func (m *Manager) selectByQuotaWeight(ctx context.Context, candidates []ProviderConfig) []ProviderConfig {
+ items := m.computeWeights(ctx, candidates)
+ withQuota, withoutQuota := partitionByQuota(items)
+ sortByStableRandomWeight(withQuota)
+ return mergeWeightedResults(withQuota, withoutQuota, len(candidates))
+}
+
+func (m *Manager) computeWeights(ctx context.Context, candidates []ProviderConfig) []weighted {
+ items := make([]weighted, 0, len(candidates))
+ for _, cfg := range candidates {
+ w := int64(0)
+ if cfg.QuotaLimit > 0 {
+ used, _ := m.GetUsage(ctx, cfg.Type)
+ if remaining := cfg.QuotaLimit - used; remaining > 0 {
+ w = remaining
+ }
+ }
+ items = append(items, weighted{cfg: cfg, weight: w})
+ }
+ return items
+}
+
+func partitionByQuota(items []weighted) (withQuota, withoutQuota []weighted) {
+ for _, item := range items {
+ if item.weight > 0 {
+ withQuota = append(withQuota, item)
+ } else {
+ withoutQuota = append(withoutQuota, item)
+ }
+ }
+ return
+}
+
+// sortByStableRandomWeight assigns a fixed random factor to each item before sorting,
+// ensuring deterministic sort behavior (transitivity) within a single call.
+func sortByStableRandomWeight(items []weighted) {
+ if len(items) <= 1 {
+ return
+ }
+ type entry struct {
+ item weighted
+ factor float64
+ }
+ entries := make([]entry, len(items))
+ for i, item := range items {
+ entries[i] = entry{item: item, factor: float64(item.weight) * (0.5 + rand.Float64())}
+ }
+ sort.Slice(entries, func(i, j int) bool {
+ return entries[i].factor > entries[j].factor
+ })
+ for i, e := range entries {
+ items[i] = e.item
+ }
+}
+
+func mergeWeightedResults(withQuota, withoutQuota []weighted, capacity int) []ProviderConfig {
+ result := make([]ProviderConfig, 0, capacity)
+ for _, item := range withQuota {
+ result = append(result, item.cfg)
+ }
+ for _, item := range withoutQuota {
+ result = append(result, item.cfg)
+ }
+ return result
+}
+
+func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
+ if cfg.APIKey == "" {
+ return false
+ }
+ if cfg.ExpiresAt != nil && time.Now().Unix() > *cfg.ExpiresAt {
+ slog.Info("websearch: provider expired, skipping",
+ "provider", cfg.Type, "expires_at", *cfg.ExpiresAt)
+ return false
+ }
+ return true
+}
+
+// --- Proxy availability tracking ---
+
+// markProxyUnavailable marks the effective proxy as unavailable for proxyUnavailableTTL.
+func (m *Manager) markProxyUnavailable(ctx context.Context, cfg ProviderConfig, accountProxyURL string) {
+ proxyID := resolveProxyID(cfg, accountProxyURL)
+ if proxyID <= 0 || m.redis == nil {
+ return
+ }
+ key := fmt.Sprintf(proxyUnavailableKey, proxyID)
+ if err := m.redis.Set(ctx, key, "1", proxyUnavailableTTL).Err(); err != nil {
+ slog.Warn("websearch: failed to mark proxy unavailable",
+ "proxy_id", proxyID, "error", err)
+ }
+}
+
+// isProxyAvailable checks whether a proxy is currently marked as unavailable.
+func (m *Manager) isProxyAvailable(ctx context.Context, proxyID int64) bool {
+ if m.redis == nil || proxyID <= 0 {
+ return true
+ }
+ key := fmt.Sprintf(proxyUnavailableKey, proxyID)
+ val, err := m.redis.Get(ctx, key).Result()
+ if err != nil {
+ return true // Redis error → assume available
+ }
+ return val == ""
+}
+
+// resolveProxyID determines the effective proxy ID for a provider+account combination.
+func resolveProxyID(cfg ProviderConfig, accountProxyURL string) int64 {
+ if accountProxyURL != "" {
+ return 0 // account proxy has no ID in provider config
+ }
+ return cfg.ProxyID
+}
+
+// isProxyError checks whether the error is likely caused by proxy or network connectivity
+// (as opposed to an API-level error from the search provider).
+func isProxyError(err error) bool {
+ if err == nil {
+ return false
+ }
+ // Network-level errors (timeout, connection refused, DNS failure)
+ var netErr net.Error
+ if errors.As(err, &netErr) {
+ return true
+ }
+ var opErr *net.OpError
+ if errors.As(err, &opErr) {
+ return true
+ }
+ // TLS handshake failures (often caused by proxy intercepting/blocking)
+ var tlsErr *tls.RecordHeaderError
+ if errors.As(err, &tlsErr) {
+ return true
+ }
+ // String-based detection for wrapped errors
+ msg := strings.ToLower(err.Error())
+ return strings.Contains(msg, "proxy") ||
+ strings.Contains(msg, "socks") ||
+ strings.Contains(msg, "connection refused") ||
+ strings.Contains(msg, "no such host") ||
+ strings.Contains(msg, "i/o timeout") ||
+ strings.Contains(msg, "tls handshake") ||
+ strings.Contains(msg, "certificate")
+}
+
+// --- Quota management ---
+
+func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) {
+ if cfg.QuotaLimit <= 0 {
+ return true, false
+ }
+ if m.redis == nil {
+ slog.Warn("websearch: Redis unavailable, quota check skipped", "provider", cfg.Type)
+ return true, false
+ }
+ key := quotaRedisKey(cfg.Type)
+ ttlSec := int(quotaTTLFromSubscription(cfg.SubscribedAt).Seconds())
+ newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64()
+ if err != nil {
+ slog.Warn("websearch: quota Lua INCR failed, allowing request",
+ "provider", cfg.Type, "error", err)
+ return true, false
+ }
+ if newVal > cfg.QuotaLimit {
+ if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil {
+ slog.Warn("websearch: quota over-limit DECR failed",
+ "provider", cfg.Type, "error", decrErr)
+ }
+ slog.Info("websearch: provider quota exhausted",
+ "provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit)
+ return false, false
+ }
+ return true, true
+}
+
+func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
+ if cfg.QuotaLimit <= 0 || m.redis == nil {
+ return
+ }
+ key := quotaRedisKey(cfg.Type)
+ if err := m.redis.Decr(ctx, key).Err(); err != nil {
+ slog.Warn("websearch: quota rollback DECR failed",
+ "provider", cfg.Type, "error", err)
+ }
+}
+
+// --- Search execution ---
+
+// TestSearch executes a search using the first available provider without reserving quota.
+// Intended for admin test functionality only.
+func (m *Manager) TestSearch(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
+ if strings.TrimSpace(req.Query) == "" {
+ return nil, "", fmt.Errorf("websearch: empty search query")
+ }
+ for _, cfg := range m.configs {
+ if !m.isProviderAvailable(cfg) {
+ continue
+ }
+ resp, err := m.executeSearch(ctx, cfg, req)
+ if err != nil {
+ continue
+ }
+ return resp, cfg.Type, nil
+ }
+ return nil, "", fmt.Errorf("websearch: no available provider")
+}
+
+func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) {
+ proxyURL := cfg.ProxyURL
+ if req.ProxyURL != "" {
+ proxyURL = req.ProxyURL
+ }
+ client, err := m.getOrCreateHTTPClient(proxyURL)
+ if err != nil {
+ return nil, fmt.Errorf("websearch: %w", err)
+ }
+ provider := m.buildProvider(cfg, client)
+ return provider.Search(ctx, req)
+}
+
+// --- HTTP client cache ---
+
+func (m *Manager) getOrCreateHTTPClient(proxyURL string) (*http.Client, error) {
+ m.clientMu.Lock()
+ defer m.clientMu.Unlock()
+
+ if c, ok := m.clientCache[proxyURL]; ok {
+ return c, nil
+ }
+ if len(m.clientCache) >= maxCachedClients {
+ m.clientCache = make(map[string]*http.Client)
+ }
+ c, err := newHTTPClient(proxyURL)
+ if err != nil {
+ return nil, err
+ }
+ m.clientCache[proxyURL] = c
+ return c, nil
+}
+
+// newHTTPClient creates an HTTP client with proper timeout settings.
+// Uses proxyutil.ConfigureTransportProxy for unified proxy protocol support
+// (HTTP/HTTPS/SOCKS5/SOCKS5H).
+// Returns error if proxyURL is invalid — never falls back to direct connection.
+func newHTTPClient(proxyURL string) (*http.Client, error) {
+ transport := &http.Transport{
+ TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
+ DialContext: (&net.Dialer{Timeout: proxyDialTimeout}).DialContext,
+ TLSHandshakeTimeout: proxyTLSTimeout,
+ ResponseHeaderTimeout: searchDataTimeout,
+ }
+ if proxyURL != "" {
+ parsed, err := url.Parse(proxyURL)
+ if err != nil {
+ return nil, fmt.Errorf("invalid proxy URL %q: %w", proxyURL, err)
+ }
+ if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
+ return nil, fmt.Errorf("configure proxy: %w", err)
+ }
+ }
+ return &http.Client{Transport: transport, Timeout: searchRequestTimeout}, nil
+}
+
+// GetUsage returns the current usage count for the given provider.
+func (m *Manager) GetUsage(ctx context.Context, providerType string) (int64, error) {
+ if m.redis == nil {
+ return 0, nil
+ }
+ key := quotaRedisKey(providerType)
+ val, err := m.redis.Get(ctx, key).Int64()
+ if err == redis.Nil {
+ return 0, nil
+ }
+ return val, err
+}
+
+// GetAllUsage returns usage for every configured provider.
+func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 {
+ result := make(map[string]int64, len(m.configs))
+ for _, cfg := range m.configs {
+ used, _ := m.GetUsage(ctx, cfg.Type)
+ result[cfg.Type] = used
+ }
+ return result
+}
+
+// ResetUsage deletes the Redis quota key for the given provider, resetting usage to 0.
+func (m *Manager) ResetUsage(ctx context.Context, providerType string) error {
+ if m.redis == nil {
+ return nil
+ }
+ key := quotaRedisKey(providerType)
+ return m.redis.Del(ctx, key).Err()
+}
+
+// --- Provider factory ---
+
+func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider {
+ switch cfg.Type {
+ case braveProviderName:
+ return NewBraveProvider(cfg.APIKey, client)
+ case tavilyProviderName:
+ return NewTavilyProvider(cfg.APIKey, client)
+ default:
+ slog.Warn("websearch: unknown provider type, falling back to brave",
+ "type", cfg.Type)
+ return NewBraveProvider(cfg.APIKey, client)
+ }
+}
+
+// --- Redis key helpers ---
+
+func quotaRedisKey(providerType string) string {
+ return quotaKeyPrefix + providerType
+}
+
+// quotaTTLFromSubscription calculates the TTL for the quota counter based on
+// the provider's subscription start date. Quota resets monthly from that date.
+// When the Redis key expires naturally, the next INCR creates a fresh counter (lazy refresh).
+func quotaTTLFromSubscription(subscribedAt *int64) time.Duration {
+ if subscribedAt == nil || *subscribedAt == 0 {
+ return defaultQuotaTTL
+ }
+ next := nextMonthlyReset(time.Unix(*subscribedAt, 0).UTC())
+ ttl := time.Until(next) + quotaTTLBuffer
+ if ttl <= quotaTTLBuffer {
+ // Already past the reset — next cycle
+ ttl = defaultQuotaTTL
+ }
+ return ttl
+}
+
+// nextMonthlyReset returns the next monthly reset time based on the subscription start date.
+// E.g., subscribed on Jan 15 → resets on Feb 15, Mar 15, etc.
+// Handles day-of-month overflow: Jan 31 → Feb 28 (not Mar 3).
+func nextMonthlyReset(subscribedAt time.Time) time.Time {
+ now := time.Now().UTC()
+ if subscribedAt.IsZero() {
+ return now.AddDate(0, 1, 0)
+ }
+ months := (now.Year()-subscribedAt.Year())*12 + int(now.Month()-subscribedAt.Month())
+ if months < 0 {
+ months = 0
+ }
+ candidate := addMonthsClamped(subscribedAt, months)
+ if candidate.After(now) {
+ return candidate
+ }
+ return addMonthsClamped(subscribedAt, months+1)
+}
+
+// addMonthsClamped adds N months to a date, clamping the day to the last day of the target month.
+// E.g., Jan 31 + 1 month = Feb 28 (not Mar 3).
+func addMonthsClamped(t time.Time, months int) time.Time {
+ y, m, d := t.Date()
+ targetMonth := time.Month(int(m) + months)
+ targetYear := y + int(targetMonth-1)/12
+ targetMonth = (targetMonth-1)%12 + 1
+ // Last day of the target month
+ lastDay := time.Date(targetYear, targetMonth+1, 0, 0, 0, 0, 0, time.UTC).Day()
+ if d > lastDay {
+ d = lastDay
+ }
+ return time.Date(targetYear, targetMonth, d, 0, 0, 0, 0, time.UTC)
+}
diff --git a/backend/internal/pkg/websearch/manager_test.go b/backend/internal/pkg/websearch/manager_test.go
new file mode 100644
index 00000000..a4413417
--- /dev/null
+++ b/backend/internal/pkg/websearch/manager_test.go
@@ -0,0 +1,323 @@
+package websearch
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewManager_PreservesOrder(t *testing.T) {
+ configs := []ProviderConfig{
+ {Type: "brave", APIKey: "k3"},
+ {Type: "tavily", APIKey: "k1"},
+ }
+ m := NewManager(configs, nil)
+ require.Equal(t, "brave", m.configs[0].Type)
+ require.Equal(t, "tavily", m.configs[1].Type)
+}
+
+func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) {
+ m := NewManager([]ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
+ _, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: ""})
+ require.ErrorContains(t, err, "empty search query")
+
+ _, _, err = m.SearchWithBestProvider(context.Background(), SearchRequest{Query: " "})
+ require.ErrorContains(t, err, "empty search query")
+}
+
+func TestManager_SearchWithBestProvider_SkipEmptyAPIKey(t *testing.T) {
+ m := NewManager([]ProviderConfig{{Type: "brave", APIKey: ""}}, nil)
+ _, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "no available provider")
+}
+
+func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) {
+ past := time.Now().Add(-1 * time.Hour).Unix()
+ m := NewManager([]ProviderConfig{
+ {Type: "brave", APIKey: "k", ExpiresAt: &past},
+ }, nil)
+ _, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "no available provider")
+}
+
+func TestManager_SearchWithBestProvider_UsesFirstAvailable(t *testing.T) {
+ srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ resp := braveResponse{}
+ resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}}
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer srvBrave.Close()
+
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srvBrave.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ m := NewManager([]ProviderConfig{
+ {Type: "brave", APIKey: "k1"},
+ {Type: "tavily", APIKey: "k2"},
+ }, nil)
+ m.clientCache[srvBrave.URL] = srvBrave.Client()
+ m.clientCache[""] = srvBrave.Client()
+
+ resp, providerName, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+ require.NoError(t, err)
+ require.Equal(t, "brave", providerName)
+ require.Len(t, resp.Results, 1)
+ require.Equal(t, "from brave", resp.Results[0].Snippet)
+}
+
+func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ resp := braveResponse{}
+ resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}}
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ m := NewManager([]ProviderConfig{
+ {Type: "brave", APIKey: "k", QuotaLimit: 100},
+ }, nil)
+ m.clientCache[""] = srv.Client()
+
+ resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+ require.NoError(t, err)
+ require.Len(t, resp.Results, 1)
+}
+
+func TestManager_GetUsage_NilRedis(t *testing.T) {
+ m := NewManager(nil, nil)
+ used, err := m.GetUsage(context.Background(), "brave")
+ require.NoError(t, err)
+ require.Equal(t, int64(0), used)
+}
+
+func TestManager_GetAllUsage_NilRedis(t *testing.T) {
+ m := NewManager([]ProviderConfig{
+ {Type: "brave"},
+ }, nil)
+ usage := m.GetAllUsage(context.Background())
+ require.Equal(t, int64(0), usage["brave"])
+}
+
+// --- Quota TTL from subscription ---
+
+func TestQuotaTTLFromSubscription_NilSubscription(t *testing.T) {
+ ttl := quotaTTLFromSubscription(nil)
+ require.Equal(t, defaultQuotaTTL, ttl)
+}
+
+func TestQuotaTTLFromSubscription_ZeroSubscription(t *testing.T) {
+ zero := int64(0)
+ ttl := quotaTTLFromSubscription(&zero)
+ require.Equal(t, defaultQuotaTTL, ttl)
+}
+
+func TestQuotaTTLFromSubscription_ValidSubscription(t *testing.T) {
+ // Subscribed 10 days ago — next reset in ~20 days
+ sub := time.Now().Add(-10 * 24 * time.Hour).Unix()
+ ttl := quotaTTLFromSubscription(&sub)
+ require.Greater(t, ttl, 15*24*time.Hour) // at least 15 days
+ require.Less(t, ttl, 25*24*time.Hour+quotaTTLBuffer)
+}
+
+func TestNextMonthlyReset_SubscribedRecentPast(t *testing.T) {
+ // Subscribed on the 10th of this month (always valid day)
+ now := time.Now().UTC()
+ sub := time.Date(now.Year(), now.Month(), 10, 0, 0, 0, 0, time.UTC)
+ next := nextMonthlyReset(sub)
+ require.True(t, next.After(now) || next.Equal(now), "next reset should be in the future or now")
+ require.True(t, next.Before(now.AddDate(0, 1, 1)))
+}
+
+func TestNextMonthlyReset_SubscribedLongAgo(t *testing.T) {
+ // Subscribed 6 months ago on the 1st
+ sub := time.Now().UTC().AddDate(0, -6, 0)
+ sub = time.Date(sub.Year(), sub.Month(), 1, 0, 0, 0, 0, time.UTC)
+ next := nextMonthlyReset(sub)
+ require.True(t, next.After(time.Now().UTC()))
+ // Should be within the next 31 days
+ require.True(t, next.Before(time.Now().UTC().AddDate(0, 1, 1)))
+}
+
+func TestNextMonthlyReset_FutureSubscription(t *testing.T) {
+ sub := time.Now().UTC().AddDate(0, 0, 5)
+ next := nextMonthlyReset(sub)
+ require.True(t, next.After(time.Now().UTC()))
+}
+
+func TestAddMonthsClamped_Jan31ToFeb(t *testing.T) {
+ sub := time.Date(2026, 1, 31, 0, 0, 0, 0, time.UTC)
+ next := addMonthsClamped(sub, 1)
+ require.Equal(t, time.Month(2), next.Month())
+ require.Equal(t, 28, next.Day()) // Feb 28 (2026 is not a leap year)
+}
+
+func TestAddMonthsClamped_Jan31ToFebLeapYear(t *testing.T) {
+ sub := time.Date(2028, 1, 31, 0, 0, 0, 0, time.UTC)
+ next := addMonthsClamped(sub, 1)
+ require.Equal(t, time.Month(2), next.Month())
+ require.Equal(t, 29, next.Day()) // Feb 29 (2028 is a leap year)
+}
+
+func TestAddMonthsClamped_Mar31ToApr(t *testing.T) {
+ sub := time.Date(2026, 3, 31, 0, 0, 0, 0, time.UTC)
+ next := addMonthsClamped(sub, 1)
+ require.Equal(t, time.Month(4), next.Month())
+ require.Equal(t, 30, next.Day()) // Apr has 30 days
+}
+
+func TestAddMonthsClamped_NormalDay(t *testing.T) {
+ sub := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC)
+ next := addMonthsClamped(sub, 1)
+ require.Equal(t, time.Month(2), next.Month())
+ require.Equal(t, 15, next.Day()) // no clamping needed
+}
+
+// --- Redis key ---
+
+func TestQuotaRedisKey_Format(t *testing.T) {
+ key := quotaRedisKey("brave")
+ require.Equal(t, "websearch:quota:brave", key)
+}
+
+// --- isProviderAvailable ---
+
+func TestIsProviderAvailable_EmptyAPIKey(t *testing.T) {
+ m := NewManager(nil, nil)
+ require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: ""}))
+}
+
+func TestIsProviderAvailable_Expired(t *testing.T) {
+ m := NewManager(nil, nil)
+ past := time.Now().Add(-1 * time.Hour).Unix()
+ require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &past}))
+}
+
+func TestIsProviderAvailable_Valid(t *testing.T) {
+ m := NewManager(nil, nil)
+ future := time.Now().Add(1 * time.Hour).Unix()
+ require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &future}))
+ require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k"})) // no expiry
+}
+
+// --- resolveProxyID ---
+
+func TestResolveProxyID_AccountProxyOverrides(t *testing.T) {
+ cfg := ProviderConfig{ProxyID: 42}
+ require.Equal(t, int64(0), resolveProxyID(cfg, "http://account-proxy:8080"))
+ require.Equal(t, int64(42), resolveProxyID(cfg, ""))
+}
+
+// --- isProxyError ---
+
+func TestIsProxyError_Nil(t *testing.T) {
+ require.False(t, isProxyError(nil))
+}
+
+func TestIsProxyError_ConnectionRefused(t *testing.T) {
+ require.True(t, isProxyError(fmt.Errorf("dial tcp: connection refused")))
+}
+
+func TestIsProxyError_Timeout(t *testing.T) {
+ require.True(t, isProxyError(fmt.Errorf("i/o timeout while connecting to proxy")))
+}
+
+func TestIsProxyError_SOCKS(t *testing.T) {
+ require.True(t, isProxyError(fmt.Errorf("socks connect failed")))
+}
+
+func TestIsProxyError_TLSHandshake(t *testing.T) {
+ require.True(t, isProxyError(fmt.Errorf("tls handshake timeout")))
+}
+
+func TestIsProxyError_APIError_NotProxy(t *testing.T) {
+ require.False(t, isProxyError(fmt.Errorf("API rate limit exceeded")))
+}
+
+// --- isProxyAvailable (nil Redis) ---
+
+func TestIsProxyAvailable_NilRedis(t *testing.T) {
+ m := NewManager(nil, nil)
+ require.True(t, m.isProxyAvailable(context.Background(), 42))
+}
+
+func TestIsProxyAvailable_ZeroID(t *testing.T) {
+ m := NewManager(nil, nil)
+ require.True(t, m.isProxyAvailable(context.Background(), 0))
+}
+
+// --- selectByQuotaWeight ---
+
+func TestSelectByQuotaWeight_NoQuotaLast(t *testing.T) {
+ m := NewManager(nil, nil)
+ candidates := []ProviderConfig{
+ {Type: "brave", APIKey: "k1", QuotaLimit: 0},
+ {Type: "tavily", APIKey: "k2", QuotaLimit: 100},
+ }
+ result := m.selectByQuotaWeight(context.Background(), candidates)
+ require.Len(t, result, 2)
+ require.Equal(t, "tavily", result[0].Type)
+ require.Equal(t, "brave", result[1].Type)
+}
+
+func TestSelectByQuotaWeight_AllNoQuota(t *testing.T) {
+ m := NewManager(nil, nil)
+ candidates := []ProviderConfig{
+ {Type: "brave", APIKey: "k1", QuotaLimit: 0},
+ {Type: "tavily", APIKey: "k2", QuotaLimit: 0},
+ }
+ result := m.selectByQuotaWeight(context.Background(), candidates)
+ require.Len(t, result, 2)
+}
+
+func TestSelectByQuotaWeight_Empty(t *testing.T) {
+ m := NewManager(nil, nil)
+ result := m.selectByQuotaWeight(context.Background(), nil)
+ require.Empty(t, result)
+}
+
+// --- newHTTPClient ---
+
+func TestNewHTTPClient_NoProxy(t *testing.T) {
+ c, err := newHTTPClient("")
+ require.NoError(t, err)
+ require.NotNil(t, c)
+}
+
+func TestNewHTTPClient_InvalidProxy(t *testing.T) {
+ _, err := newHTTPClient("://bad-url")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "invalid proxy URL")
+}
+
+func TestNewHTTPClient_ValidHTTPProxy(t *testing.T) {
+ c, err := newHTTPClient("http://proxy.example.com:8080")
+ require.NoError(t, err)
+ require.NotNil(t, c)
+}
+
+func TestNewHTTPClient_ValidSOCKS5Proxy(t *testing.T) {
+ c, err := newHTTPClient("socks5://proxy.example.com:1080")
+ require.NoError(t, err)
+ require.NotNil(t, c)
+}
+
+// --- ResetUsage ---
+
+func TestManager_ResetUsage_NilRedis(t *testing.T) {
+ m := NewManager(nil, nil)
+ err := m.ResetUsage(context.Background(), "brave")
+ require.NoError(t, err)
+}
diff --git a/backend/internal/pkg/websearch/provider.go b/backend/internal/pkg/websearch/provider.go
new file mode 100644
index 00000000..3424c056
--- /dev/null
+++ b/backend/internal/pkg/websearch/provider.go
@@ -0,0 +1,11 @@
+package websearch
+
+import "context"
+
+// Provider is the interface every search backend must implement.
+type Provider interface {
+ // Name returns the provider identifier ("brave" or "tavily").
+ Name() string
+ // Search executes a web search and returns results.
+ Search(ctx context.Context, req SearchRequest) (*SearchResponse, error)
+}
diff --git a/backend/internal/pkg/websearch/tavily.go b/backend/internal/pkg/websearch/tavily.go
new file mode 100644
index 00000000..ac4928a6
--- /dev/null
+++ b/backend/internal/pkg/websearch/tavily.go
@@ -0,0 +1,107 @@
+package websearch
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+)
+
+const (
+ tavilySearchEndpoint = "https://api.tavily.com/search"
+ tavilyProviderName = "tavily"
+ tavilySearchDepthBasic = "basic"
+)
+
+// TavilyProvider implements web search via the Tavily Search API.
+type TavilyProvider struct {
+ apiKey string
+ httpClient *http.Client
+}
+
+// NewTavilyProvider creates a Tavily Search provider.
+// The caller is responsible for configuring the http.Client with proxy/timeouts.
+func NewTavilyProvider(apiKey string, httpClient *http.Client) *TavilyProvider {
+ if httpClient == nil {
+ httpClient = http.DefaultClient
+ }
+ return &TavilyProvider{apiKey: apiKey, httpClient: httpClient}
+}
+
+func (t *TavilyProvider) Name() string { return tavilyProviderName }
+
+func (t *TavilyProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
+ maxResults := req.MaxResults
+ if maxResults <= 0 {
+ maxResults = defaultMaxResults
+ }
+
+ payload := tavilyRequest{
+ APIKey: t.apiKey,
+ Query: req.Query,
+ MaxResults: maxResults,
+ SearchDepth: tavilySearchDepthBasic,
+ }
+
+ bodyBytes, err := json.Marshal(payload)
+ if err != nil {
+ return nil, fmt.Errorf("tavily: encode request: %w", err)
+ }
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilySearchEndpoint, bytes.NewReader(bodyBytes))
+ if err != nil {
+ return nil, fmt.Errorf("tavily: build request: %w", err)
+ }
+ httpReq.Header.Set("Content-Type", "application/json")
+
+ resp, err := t.httpClient.Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("tavily: request failed: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
+ if err != nil {
+ return nil, fmt.Errorf("tavily: read body: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("tavily: status %d: %s", resp.StatusCode, truncateBody(body))
+ }
+
+ var raw tavilyResponse
+ if err := json.Unmarshal(body, &raw); err != nil {
+ return nil, fmt.Errorf("tavily: decode response: %w", err)
+ }
+
+ results := make([]SearchResult, 0, len(raw.Results))
+ for _, r := range raw.Results {
+ results = append(results, SearchResult{
+ URL: r.URL,
+ Title: r.Title,
+ Snippet: r.Content,
+ })
+ }
+
+ return &SearchResponse{Results: results, Query: req.Query}, nil
+}
+
+type tavilyRequest struct {
+ APIKey string `json:"api_key"`
+ Query string `json:"query"`
+ MaxResults int `json:"max_results"`
+ SearchDepth string `json:"search_depth"`
+}
+
+type tavilyResponse struct {
+ Results []tavilyResult `json:"results"`
+}
+
+type tavilyResult struct {
+ URL string `json:"url"`
+ Title string `json:"title"`
+ Content string `json:"content"`
+ Score float64 `json:"score"`
+}
diff --git a/backend/internal/pkg/websearch/tavily_test.go b/backend/internal/pkg/websearch/tavily_test.go
new file mode 100644
index 00000000..e1b6819a
--- /dev/null
+++ b/backend/internal/pkg/websearch/tavily_test.go
@@ -0,0 +1,63 @@
+package websearch
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestTavilyProvider_Name(t *testing.T) {
+ p := NewTavilyProvider("key", nil)
+ require.Equal(t, "tavily", p.Name())
+}
+
+func TestTavilyProvider_Search_RequestConstruction(t *testing.T) {
+ // Verify tavilyRequest struct fields map correctly
+ req := tavilyRequest{
+ APIKey: "test-key",
+ Query: "golang",
+ MaxResults: 3,
+ SearchDepth: tavilySearchDepthBasic,
+ }
+ data, err := json.Marshal(req)
+ require.NoError(t, err)
+
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(data, &parsed))
+ require.Equal(t, "test-key", parsed["api_key"])
+ require.Equal(t, "golang", parsed["query"])
+ require.Equal(t, float64(3), parsed["max_results"])
+ require.Equal(t, "basic", parsed["search_depth"])
+}
+
+func TestTavilyProvider_Search_ResponseParsing(t *testing.T) {
+ rawResp := `{"results":[{"url":"https://go.dev","title":"Go","content":"Go programming language","score":0.95}]}`
+ var resp tavilyResponse
+ require.NoError(t, json.Unmarshal([]byte(rawResp), &resp))
+ require.Len(t, resp.Results, 1)
+ require.Equal(t, "https://go.dev", resp.Results[0].URL)
+ require.Equal(t, "Go programming language", resp.Results[0].Content)
+ require.InDelta(t, 0.95, resp.Results[0].Score, 0.001)
+
+ // Verify mapping to SearchResult
+ results := make([]SearchResult, 0, len(resp.Results))
+ for _, r := range resp.Results {
+ results = append(results, SearchResult{
+ URL: r.URL, Title: r.Title, Snippet: r.Content,
+ })
+ }
+ require.Equal(t, "Go programming language", results[0].Snippet)
+ require.Equal(t, "", results[0].PageAge)
+}
+
+func TestTavilyProvider_Search_EmptyResults(t *testing.T) {
+ var resp tavilyResponse
+ require.NoError(t, json.Unmarshal([]byte(`{"results":[]}`), &resp))
+ require.Empty(t, resp.Results)
+}
+
+func TestTavilyProvider_Search_InvalidJSON(t *testing.T) {
+ var resp tavilyResponse
+ require.Error(t, json.Unmarshal([]byte("not json"), &resp))
+}
diff --git a/backend/internal/pkg/websearch/types.go b/backend/internal/pkg/websearch/types.go
new file mode 100644
index 00000000..bb489690
--- /dev/null
+++ b/backend/internal/pkg/websearch/types.go
@@ -0,0 +1,30 @@
+package websearch
+
+// SearchResult represents a single web search result.
+type SearchResult struct {
+ URL string `json:"url"`
+ Title string `json:"title"`
+ Snippet string `json:"snippet"`
+ PageAge string `json:"page_age,omitempty"`
+}
+
+// SearchRequest describes a web search to perform.
+type SearchRequest struct {
+ Query string
+ MaxResults int // defaults to defaultMaxResults if <= 0
+ ProxyURL string // optional HTTP proxy URL
+}
+
+// SearchResponse holds the results of a web search.
+type SearchResponse struct {
+ Results []SearchResult
+ Query string // the query that was actually executed
+}
+
+const defaultMaxResults = 5
+
+// Provider type identifiers.
+const (
+ ProviderTypeBrave = "brave"
+ ProviderTypeTavily = "tavily"
+)
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index 7fd98855..38ea9bde 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -138,10 +138,17 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
WithUser(func(q *dbent.UserQuery) {
q.Select(
user.FieldID,
+ user.FieldEmail,
+ user.FieldUsername,
user.FieldStatus,
user.FieldRole,
user.FieldBalance,
user.FieldConcurrency,
+ user.FieldBalanceNotifyEnabled,
+ user.FieldBalanceNotifyThresholdType,
+ user.FieldBalanceNotifyThreshold,
+ user.FieldBalanceNotifyExtraEmails,
+ user.FieldTotalRecharged,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
@@ -639,22 +646,31 @@ func userEntityToService(u *dbent.User) *service.User {
if u == nil {
return nil
}
- return &service.User{
- ID: u.ID,
- Email: u.Email,
- Username: u.Username,
- Notes: u.Notes,
- PasswordHash: u.PasswordHash,
- Role: u.Role,
- Balance: u.Balance,
- Concurrency: u.Concurrency,
- Status: u.Status,
- TotpSecretEncrypted: u.TotpSecretEncrypted,
- TotpEnabled: u.TotpEnabled,
- TotpEnabledAt: u.TotpEnabledAt,
- CreatedAt: u.CreatedAt,
- UpdatedAt: u.UpdatedAt,
+ out := &service.User{
+ ID: u.ID,
+ Email: u.Email,
+ Username: u.Username,
+ Notes: u.Notes,
+ PasswordHash: u.PasswordHash,
+ Role: u.Role,
+ Balance: u.Balance,
+ Concurrency: u.Concurrency,
+ Status: u.Status,
+ TotpSecretEncrypted: u.TotpSecretEncrypted,
+ TotpEnabled: u.TotpEnabled,
+ TotpEnabledAt: u.TotpEnabledAt,
+ BalanceNotifyEnabled: u.BalanceNotifyEnabled,
+ BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
+ BalanceNotifyThreshold: u.BalanceNotifyThreshold,
+ TotalRecharged: u.TotalRecharged,
+ CreatedAt: u.CreatedAt,
+ UpdatedAt: u.UpdatedAt,
}
+ // Parse extra emails JSON (supports both old []string and new []NotifyEmailEntry format)
+ if u.BalanceNotifyExtraEmails != "" && u.BalanceNotifyExtraEmails != "[]" {
+ out.BalanceNotifyExtraEmails = service.ParseNotifyEmails(u.BalanceNotifyExtraEmails)
+ }
+ return out
}
func groupEntityToService(g *dbent.Group) *service.Group {
diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go
index 49c2d8d9..2cb90aab 100644
--- a/backend/internal/repository/channel_repo.go
+++ b/backend/internal/repository/channel_repo.go
@@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
if err != nil {
return err
}
+ featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
+ if err != nil {
+ return err
+ }
err = tx.QueryRowContext(ctx,
- `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6)
+ `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id, created_at, updated_at`,
- channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels,
+ channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
@@ -67,17 +71,24 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
}
}
+ // 设置账号统计定价规则
+ if len(channel.AccountStatsPricingRules) > 0 {
+ if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
+ return err
+ }
+ }
+
return nil
})
}
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
ch := &service.Channel{}
- var modelMappingJSON []byte
+ var modelMappingJSON, featuresConfigJSON []byte
err := r.db.QueryRowContext(ctx,
- `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at
+ `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at
FROM channels WHERE id = $1`, id,
- ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt)
+ ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound
}
@@ -85,6 +96,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
return nil, fmt.Errorf("get channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
+ ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
groupIDs, err := r.GetGroupIDs(ctx, id)
if err != nil {
@@ -98,6 +110,12 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
}
ch.ModelPricing = pricing
+ statsPricingRules, err := r.loadAccountStatsPricingRules(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ ch.AccountStatsPricingRules = statsPricingRules
+
return ch, nil
}
@@ -107,10 +125,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
if err != nil {
return err
}
+ featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
+ if err != nil {
+ return err
+ }
result, err := tx.ExecContext(ctx,
- `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW()
- WHERE id = $7`,
- channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID,
+ `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, apply_pricing_to_account_stats = $9, updated_at = NOW()
+ WHERE id = $10`,
+ channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, channel.ID,
)
if err != nil {
if isUniqueViolation(err) {
@@ -137,6 +159,13 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
}
}
+ // 更新账号统计定价规则
+ if channel.AccountStatsPricingRules != nil {
+ if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
+ return err
+ }
+ }
+
return nil
})
}
@@ -187,7 +216,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表
dataQuery := fmt.Sprintf(
- `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
+ `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.apply_pricing_to_account_stats, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
)
@@ -203,11 +232,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
var channelIDs []int64
for rows.Next() {
var ch service.Channel
- var modelMappingJSON []byte
- if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
+ var modelMappingJSON, featuresConfigJSON []byte
+ if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
+ ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -225,9 +255,14 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
if err != nil {
return nil, nil, err
}
+ statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
+ if err != nil {
+ return nil, nil, err
+ }
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
+ channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
}
}
@@ -273,7 +308,7 @@ func channelListOrderBy(params pagination.PaginationParams) string {
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
- `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`,
+ `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`,
)
if err != nil {
return nil, fmt.Errorf("query all channels: %w", err)
@@ -284,11 +319,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
var channelIDs []int64
for rows.Next() {
var ch service.Channel
- var modelMappingJSON []byte
- if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
+ var modelMappingJSON, featuresConfigJSON []byte
+ if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
+ ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -312,9 +348,16 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
return nil, err
}
+ // 批量加载账号统计定价规则
+ statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
+ if err != nil {
+ return nil, err
+ }
+
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
+ channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
}
return channels, nil
@@ -456,6 +499,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
return m
}
+func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
+ if len(m) == 0 {
+ return []byte("{}"), nil
+ }
+ data, err := json.Marshal(m)
+ if err != nil {
+ return nil, fmt.Errorf("marshal features_config: %w", err)
+ }
+ return data, nil
+}
+
+func unmarshalFeaturesConfig(data []byte) map[string]any {
+ if len(data) == 0 {
+ return nil
+ }
+ var m map[string]any
+ if err := json.Unmarshal(data, &m); err != nil {
+ return nil
+ }
+ return m
+}
+
// GetGroupPlatforms 批量查询分组 ID 对应的平台
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
if len(groupIDs) == 0 {
diff --git a/backend/internal/repository/channel_repo_account_stats_pricing.go b/backend/internal/repository/channel_repo_account_stats_pricing.go
new file mode 100644
index 00000000..9e00fed8
--- /dev/null
+++ b/backend/internal/repository/channel_repo_account_stats_pricing.go
@@ -0,0 +1,244 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
+)
+
+// --- 账号统计定价规则 ---
+
+// batchLoadAccountStatsPricingRules 批量加载多个渠道的账号统计定价规则(含模型定价)
+func (r *channelRepository) batchLoadAccountStatsPricingRules(ctx context.Context, channelIDs []int64) (map[int64][]service.AccountStatsPricingRule, error) {
+ // 1. 查询规则
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT id, channel_id, name, group_ids, account_ids, sort_order, created_at, updated_at
+ FROM channel_account_stats_pricing_rules WHERE channel_id = ANY($1) ORDER BY channel_id, sort_order, id`,
+ pq.Array(channelIDs),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("batch load account stats pricing rules: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ var allRules []service.AccountStatsPricingRule
+ var ruleIDs []int64
+ for rows.Next() {
+ var rule service.AccountStatsPricingRule
+ if err := rows.Scan(
+ &rule.ID, &rule.ChannelID, &rule.Name,
+ pq.Array(&rule.GroupIDs), pq.Array(&rule.AccountIDs),
+ &rule.SortOrder, &rule.CreatedAt, &rule.UpdatedAt,
+ ); err != nil {
+ return nil, fmt.Errorf("scan account stats pricing rule: %w", err)
+ }
+ ruleIDs = append(ruleIDs, rule.ID)
+ allRules = append(allRules, rule)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterate account stats pricing rules: %w", err)
+ }
+
+ // 2. 批量加载规则的模型定价
+ pricingMap, err := r.batchLoadAccountStatsModelPricing(ctx, ruleIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ // 3. 按 channelID 分组并关联定价
+ result := make(map[int64][]service.AccountStatsPricingRule, len(channelIDs))
+ for i := range allRules {
+ allRules[i].Pricing = pricingMap[allRules[i].ID]
+ result[allRules[i].ChannelID] = append(result[allRules[i].ChannelID], allRules[i])
+ }
+
+ return result, nil
+}
+
+// batchLoadAccountStatsModelPricing 批量加载规则的模型定价
+func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Context, ruleIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
+ if len(ruleIDs) == 0 {
+ return make(map[int64][]service.ChannelModelPricing), nil
+ }
+
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT id, rule_id, platform, models, billing_mode, input_price, output_price,
+ cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
+ FROM channel_account_stats_model_pricing WHERE rule_id = ANY($1) ORDER BY rule_id, id`,
+ pq.Array(ruleIDs),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("batch load account stats model pricing: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ pricingMap := make(map[int64][]service.ChannelModelPricing, len(ruleIDs))
+ for rows.Next() {
+ var p service.ChannelModelPricing
+ var ruleID int64
+ var modelsJSON []byte
+ if err := rows.Scan(
+ &p.ID, &ruleID, &p.Platform, &modelsJSON, &p.BillingMode,
+ &p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
+ &p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
+ ); err != nil {
+ return nil, fmt.Errorf("scan account stats model pricing: %w", err)
+ }
+ if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
+ p.Models = []string{}
+ }
+ pricingMap[ruleID] = append(pricingMap[ruleID], p)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterate account stats model pricing: %w", err)
+ }
+
+ // Load intervals for all pricing entries.
+ var allPricingIDs []int64
+ for _, pricings := range pricingMap {
+ for _, p := range pricings {
+ allPricingIDs = append(allPricingIDs, p.ID)
+ }
+ }
+ if len(allPricingIDs) > 0 {
+ intervalsMap, err := r.batchLoadAccountStatsIntervals(ctx, allPricingIDs)
+ if err != nil {
+ return nil, err
+ }
+ for ruleID, pricings := range pricingMap {
+ for i := range pricings {
+ pricings[i].Intervals = intervalsMap[pricings[i].ID]
+ }
+ pricingMap[ruleID] = pricings
+ }
+ }
+
+ return pricingMap, nil
+}
+
+// loadAccountStatsPricingRules 加载单个渠道的账号统计定价规则(供 GetByID 使用)
+func (r *channelRepository) loadAccountStatsPricingRules(ctx context.Context, channelID int64) ([]service.AccountStatsPricingRule, error) {
+ result, err := r.batchLoadAccountStatsPricingRules(ctx, []int64{channelID})
+ if err != nil {
+ return nil, err
+ }
+ return result[channelID], nil
+}
+
+// replaceAccountStatsPricingRulesTx 在事务中替换渠道的账号统计定价规则(删除旧的 + 插入新的)
+func replaceAccountStatsPricingRulesTx(ctx context.Context, tx *sql.Tx, channelID int64, rules []service.AccountStatsPricingRule) error {
+ // CASCADE 会自动删除关联的 model_pricing
+ if _, err := tx.ExecContext(ctx,
+ `DELETE FROM channel_account_stats_pricing_rules WHERE channel_id = $1`, channelID,
+ ); err != nil {
+ return fmt.Errorf("delete old account stats pricing rules: %w", err)
+ }
+
+ for i := range rules {
+ rules[i].ChannelID = channelID
+ if err := createAccountStatsPricingRuleTx(ctx, tx, &rules[i]); err != nil {
+ return fmt.Errorf("insert account stats pricing rule: %w", err)
+ }
+ }
+ return nil
+}
+
+// createAccountStatsPricingRuleTx 在事务中创建单条账号统计定价规则及其模型定价
+func createAccountStatsPricingRuleTx(ctx context.Context, tx *sql.Tx, rule *service.AccountStatsPricingRule) error {
+ err := tx.QueryRowContext(ctx,
+ `INSERT INTO channel_account_stats_pricing_rules (channel_id, name, group_ids, account_ids, sort_order)
+ VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
+ rule.ChannelID, rule.Name, pq.Array(rule.GroupIDs), pq.Array(rule.AccountIDs), rule.SortOrder,
+ ).Scan(&rule.ID, &rule.CreatedAt, &rule.UpdatedAt)
+ if err != nil {
+ return fmt.Errorf("insert account stats pricing rule: %w", err)
+ }
+
+ for j := range rule.Pricing {
+ if err := createAccountStatsModelPricingTx(ctx, tx, rule.ID, &rule.Pricing[j]); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// createAccountStatsModelPricingTx 在事务中创建单条账号统计模型定价
+func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID int64, pricing *service.ChannelModelPricing) error {
+ modelsJSON, err := json.Marshal(pricing.Models)
+ if err != nil {
+ return fmt.Errorf("marshal models: %w", err)
+ }
+ billingMode := pricing.BillingMode
+ if billingMode == "" {
+ billingMode = service.BillingModeToken
+ }
+ platform := pricing.Platform
+ err = tx.QueryRowContext(ctx,
+ `INSERT INTO channel_account_stats_model_pricing (rule_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
+ ruleID, platform, modelsJSON, billingMode,
+ pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
+ pricing.ImageOutputPrice, pricing.PerRequestPrice,
+ ).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
+ if err != nil {
+ return fmt.Errorf("insert account stats model pricing: %w", err)
+ }
+ // Persist intervals (mirrors channel_pricing_intervals logic).
+ for i := range pricing.Intervals {
+ iv := &pricing.Intervals[i]
+ iv.PricingID = pricing.ID
+ if err := createAccountStatsIntervalTx(ctx, tx, iv); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// createAccountStatsIntervalTx inserts a single interval for an account stats pricing entry.
+func createAccountStatsIntervalTx(ctx context.Context, tx *sql.Tx, iv *service.PricingInterval) error {
+ return tx.QueryRowContext(ctx,
+ `INSERT INTO channel_account_stats_pricing_intervals
+ (pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
+ iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
+ iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
+ iv.PerRequestPrice, iv.SortOrder,
+ ).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
+}
+
+// batchLoadAccountStatsIntervals loads intervals for account stats pricing entries.
+func (r *channelRepository) batchLoadAccountStatsIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
+ if len(pricingIDs) == 0 {
+ return nil, nil
+ }
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
+ input_price, output_price, cache_write_price, cache_read_price,
+ per_request_price, sort_order, created_at, updated_at
+ FROM channel_account_stats_pricing_intervals
+ WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
+ pq.Array(pricingIDs),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("batch load account stats pricing intervals: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ result := make(map[int64][]service.PricingInterval)
+ for rows.Next() {
+ var iv service.PricingInterval
+ if err := rows.Scan(
+ &iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
+ &iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
+ &iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
+ ); err != nil {
+ return nil, fmt.Errorf("scan account stats pricing interval: %w", err)
+ }
+ result[iv.PricingID] = append(result[iv.PricingID], iv)
+ }
+ return result, rows.Err()
+}
diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go
index 8f2b8eca..96a23a8e 100644
--- a/backend/internal/repository/email_cache.go
+++ b/backend/internal/repository/email_cache.go
@@ -3,6 +3,8 @@ package repository
import (
"context"
"encoding/json"
+ "fmt"
+ "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -11,23 +13,33 @@ import (
const (
verifyCodeKeyPrefix = "verify_code:"
+ notifyVerifyKeyPrefix = "notify_verify:"
passwordResetKeyPrefix = "password_reset:"
passwordResetSentAtKeyPrefix = "password_reset_sent:"
+ notifyCodeUserRateKeyPrefix = "notify_code_user_rate:"
)
// verifyCodeKey generates the Redis key for email verification code.
+// Email is lowercased for case-insensitive consistency.
func verifyCodeKey(email string) string {
- return verifyCodeKeyPrefix + email
+ return verifyCodeKeyPrefix + strings.ToLower(email)
+}
+
+// notifyVerifyKey generates the Redis key for notify email verification code.
+// Email is lowercased to prevent case-sensitive key mismatch (the business layer
+// uses strings.EqualFold for comparison).
+func notifyVerifyKey(email string) string {
+ return notifyVerifyKeyPrefix + strings.ToLower(email)
}
// passwordResetKey generates the Redis key for password reset token.
func passwordResetKey(email string) string {
- return passwordResetKeyPrefix + email
+ return passwordResetKeyPrefix + strings.ToLower(email)
}
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
func passwordResetSentAtKey(email string) string {
- return passwordResetSentAtKeyPrefix + email
+ return passwordResetSentAtKeyPrefix + strings.ToLower(email)
}
type emailCache struct {
@@ -106,3 +118,60 @@ func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email st
key := passwordResetSentAtKey(email)
return c.rdb.Set(ctx, key, "1", ttl).Err()
}
+
+// Notify email verification code methods
+
+func (c *emailCache) GetNotifyVerifyCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
+ key := notifyVerifyKey(email)
+ val, err := c.rdb.Get(ctx, key).Result()
+ if err != nil {
+ return nil, err
+ }
+ var data service.VerificationCodeData
+ if err := json.Unmarshal([]byte(val), &data); err != nil {
+ return nil, err
+ }
+ return &data, nil
+}
+
+func (c *emailCache) SetNotifyVerifyCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
+ key := notifyVerifyKey(email)
+ val, err := json.Marshal(data)
+ if err != nil {
+ return err
+ }
+ return c.rdb.Set(ctx, key, val, ttl).Err()
+}
+
+func (c *emailCache) DeleteNotifyVerifyCode(ctx context.Context, email string) error {
+ key := notifyVerifyKey(email)
+ return c.rdb.Del(ctx, key).Err()
+}
+
+// User-level rate limiting for notify email verification codes
+
+func notifyCodeUserRateKey(userID int64) string {
+ return notifyCodeUserRateKeyPrefix + fmt.Sprintf("%d", userID)
+}
+
+func (c *emailCache) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
+ key := notifyCodeUserRateKey(userID)
+ count, err := c.rdb.Incr(ctx, key).Result()
+ if err != nil {
+ return 0, err
+ }
+ // Always set TTL (idempotent) to avoid orphan keys if process crashes between INCR and EXPIRE.
+ if err := c.rdb.Expire(ctx, key, window).Err(); err != nil {
+ return count, fmt.Errorf("expire notify code rate key: %w", err)
+ }
+ return count, nil
+}
+
+func (c *emailCache) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
+ key := notifyCodeUserRateKey(userID)
+ count, err := c.rdb.Get(ctx, key).Int64()
+ if err != nil {
+ return 0, err
+ }
+ return count, nil
+}
diff --git a/backend/internal/repository/usage_billing_repo.go b/backend/internal/repository/usage_billing_repo.go
index b4c76da5..2b6edad3 100644
--- a/backend/internal/repository/usage_billing_repo.go
+++ b/backend/internal/repository/usage_billing_repo.go
@@ -113,9 +113,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
}
if cmd.BalanceCost > 0 {
- if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
+ newBalance, err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost)
+ if err != nil {
return err
}
+ result.NewBalance = &newBalance
}
if cmd.APIKeyQuotaCost > 0 {
@@ -133,9 +135,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
}
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
- if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
+ quotaState, err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost)
+ if err != nil {
return err
}
+ result.QuotaState = quotaState
}
return nil
@@ -169,24 +173,22 @@ func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscrip
return service.ErrSubscriptionNotFound
}
-func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
- res, err := tx.ExecContext(ctx, `
+func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) (float64, error) {
+ var newBalance float64
+ err := tx.QueryRowContext(ctx, `
UPDATE users
SET balance = balance - $1,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
- `, amount, userID)
+ RETURNING balance
+ `, amount, userID).Scan(&newBalance)
+ if errors.Is(err, sql.ErrNoRows) {
+ return 0, service.ErrUserNotFound
+ }
if err != nil {
- return err
+ return 0, err
}
- affected, err := res.RowsAffected()
- if err != nil {
- return err
- }
- if affected > 0 {
- return nil
- }
- return service.ErrUserNotFound
+ return newBalance, nil
}
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
@@ -240,7 +242,7 @@ func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKe
return nil
}
-func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
+func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) (*service.AccountQuotaState, error) {
rows, err := tx.QueryContext(ctx,
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
@@ -248,61 +250,71 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
- CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
- + '24 hours'::interval <= NOW()
+ CASE WHEN `+dailyExpiredExpr+`
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
- CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
- + '24 hours'::interval <= NOW()
+ CASE WHEN `+dailyExpiredExpr+`
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
)
+ || CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
+ THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
+ ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
- CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
- + '168 hours'::interval <= NOW()
+ CASE WHEN `+weeklyExpiredExpr+`
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
- CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
- + '168 hours'::interval <= NOW()
+ CASE WHEN `+weeklyExpiredExpr+`
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
)
+ || CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
+ THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
+ ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING
COALESCE((extra->>'quota_used')::numeric, 0),
- COALESCE((extra->>'quota_limit')::numeric, 0)`,
+ COALESCE((extra->>'quota_limit')::numeric, 0),
+ COALESCE((extra->>'quota_daily_used')::numeric, 0),
+ COALESCE((extra->>'quota_daily_limit')::numeric, 0),
+ COALESCE((extra->>'quota_weekly_used')::numeric, 0),
+ COALESCE((extra->>'quota_weekly_limit')::numeric, 0)`,
amount, accountID)
if err != nil {
- return err
+ return nil, err
}
defer func() { _ = rows.Close() }()
- var newUsed, limit float64
+ var state service.AccountQuotaState
if rows.Next() {
- if err := rows.Scan(&newUsed, &limit); err != nil {
- return err
+ if err := rows.Scan(
+ &state.TotalUsed, &state.TotalLimit,
+ &state.DailyUsed, &state.DailyLimit,
+ &state.WeeklyUsed, &state.WeeklyLimit,
+ ); err != nil {
+ return nil, err
}
} else {
if err := rows.Err(); err != nil {
- return err
+ return nil, err
}
- return service.ErrAccountNotFound
+ return nil, service.ErrAccountNotFound
}
if err := rows.Err(); err != nil {
- return err
+ return nil, err
}
- if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
+ if state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit {
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
- return err
+ return nil, err
}
}
- return nil
+ return &state, nil
}
diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go
index 3ba2191e..f942a8e1 100644
--- a/backend/internal/repository/usage_log_repo.go
+++ b/backend/internal/repository/usage_log_repo.go
@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
-const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
+const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
@@ -82,6 +82,7 @@ var usageLogInsertArgTypes = [...]string{
"text", // model_mapping_chain
"text", // billing_tier
"text", // billing_mode
+ "numeric", // account_stats_cost
"timestamptz", // created_at
}
@@ -360,6 +361,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
@@ -367,7 +369,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
- $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
+ $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -797,6 +799,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
) AS (VALUES `)
@@ -873,6 +876,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
)
SELECT
@@ -920,6 +924,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
@@ -1007,10 +1012,11 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
) AS (VALUES `)
- args := make([]any, 0, len(preparedList)*45)
+ args := make([]any, 0, len(preparedList)*46)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@@ -1080,6 +1086,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
)
SELECT
@@ -1127,6 +1134,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
@@ -1182,6 +1190,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
@@ -1189,7 +1198,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
- $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
+ $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@@ -1285,6 +1294,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
modelMappingChain,
billingTier,
billingMode,
+ log.AccountStatsCost, // account_stats_cost
createdAt,
},
}
@@ -1959,7 +1969,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -1989,7 +1999,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -2026,7 +2036,7 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
account_id,
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -2990,7 +3000,7 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if accountID > 0 && userID == 0 && apiKeyID == 0 {
- actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
+ actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
modelExpr := resolveModelDimensionExpression(source)
@@ -3358,7 +3368,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
@@ -3433,7 +3443,7 @@ type EndpointStat = usagestats.EndpointStat
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
- actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
+ actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
@@ -3500,7 +3510,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
- actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
+ actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
@@ -3591,7 +3601,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
@@ -4069,6 +4079,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
modelMappingChain sql.NullString
billingTier sql.NullString
billingMode sql.NullString
+ accountStatsCost sql.NullFloat64
createdAt time.Time
)
@@ -4118,6 +4129,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&modelMappingChain,
&billingTier,
&billingMode,
+ &accountStatsCost,
&createdAt,
); err != nil {
return nil, err
@@ -4214,6 +4226,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if billingMode.Valid {
log.BillingMode = &billingMode.String
}
+ if accountStatsCost.Valid {
+ log.AccountStatsCost = &accountStatsCost.Float64
+ }
return log, nil
}
diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go
index b9cb6a13..acdd6e62 100644
--- a/backend/internal/repository/usage_log_repo_request_type_test.go
+++ b/backend/internal/repository/usage_log_repo_request_type_test.go
@@ -85,6 +85,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
+ sqlmock.AnyArg(), // account_stats_cost
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
@@ -163,6 +164,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
+ sqlmock.AnyArg(), // account_stats_cost
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
@@ -483,10 +485,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
- sql.NullInt64{}, // channel_id
- sql.NullString{}, // model_mapping_chain
- sql.NullString{}, // billing_tier
- sql.NullString{}, // billing_mode
+ sql.NullInt64{}, // channel_id
+ sql.NullString{}, // model_mapping_chain
+ sql.NullString{}, // billing_tier
+ sql.NullString{}, // billing_mode
+ sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)
@@ -530,10 +533,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
- sql.NullInt64{}, // channel_id
- sql.NullString{}, // model_mapping_chain
- sql.NullString{}, // billing_tier
- sql.NullString{}, // billing_mode
+ sql.NullInt64{}, // channel_id
+ sql.NullString{}, // model_mapping_chain
+ sql.NullString{}, // billing_tier
+ sql.NullString{}, // billing_mode
+ sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)
@@ -577,10 +581,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
- sql.NullInt64{}, // channel_id
- sql.NullString{}, // model_mapping_chain
- sql.NullString{}, // billing_tier
- sql.NullString{}, // billing_mode
+ sql.NullInt64{}, // channel_id
+ sql.NullString{}, // model_mapping_chain
+ sql.NullString{}, // billing_tier
+ sql.NullString{}, // billing_mode
+ sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)
diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go
index e2471ae5..eca5313f 100644
--- a/backend/internal/repository/user_group_rate_repo.go
+++ b/backend/internal/repository/user_group_rate_repo.go
@@ -100,7 +100,7 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
query := `
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
FROM user_group_rate_multipliers ugr
- JOIN users u ON u.id = ugr.user_id
+ JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
WHERE ugr.group_id = $1
ORDER BY ugr.user_id
`
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index d5a13607..913e1c40 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -137,7 +137,7 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
txClient = r.client
}
- updated, err := txClient.User.UpdateOneID(userIn.ID).
+ updateOp := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetNotes(userIn.Notes).
@@ -146,7 +146,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
- Save(ctx)
+ SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled).
+ SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
+ SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
+ SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
+ SetTotalRecharged(userIn.TotalRecharged)
+ if userIn.BalanceNotifyThreshold == nil {
+ updateOp = updateOp.ClearBalanceNotifyThreshold()
+ }
+ updated, err := updateOp.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
}
@@ -382,7 +390,12 @@ func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client)
- n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
+ update := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount)
+ // Track cumulative recharge amount for percentage-based notifications
+ if amount > 0 {
+ update = update.AddTotalRecharged(amount)
+ }
+ n, err := update.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
@@ -549,6 +562,11 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst.UpdatedAt = src.UpdatedAt
}
+// marshalExtraEmails serializes notify email entries to JSON for storage.
+func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
+ return service.MarshalNotifyEmails(entries)
+}
+
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
client := clientFromContext(ctx, r.client)
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 1a4892fa..44c3f0e4 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -58,6 +58,11 @@ func TestAPIContracts(t *testing.T) {
"allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z",
+ "balance_notify_enabled": false,
+ "balance_notify_threshold_type": "",
+ "balance_notify_threshold": null,
+ "balance_notify_extra_emails": null,
+ "total_recharged": 0,
"run_mode": "standard"
}
}`,
@@ -204,11 +209,10 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
- "claude_code_only": false,
+ "claude_code_only": false,
"allow_messages_dispatch": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
- "allow_messages_dispatch": false,
"require_oauth_only": false,
"require_privacy_set": false,
"created_at": "2025-01-02T03:04:05Z",
@@ -587,26 +591,32 @@ func TestAPIContracts(t *testing.T) {
"enable_cch_signing": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
+ "web_search_emulation_enabled": false,
+ "custom_menu_items": [],
+ "custom_endpoints": [],
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
"payment_daily_limit": 0,
"payment_order_timeout_minutes": 0,
"payment_max_pending_orders": 0,
- "payment_enabled_types": null,
"payment_balance_disabled": false,
"payment_load_balance_strategy": "",
"payment_product_name_prefix": "",
"payment_product_name_suffix": "",
"payment_help_image_url": "",
"payment_help_text": "",
+ "payment_enabled_types": null,
"payment_cancel_rate_limit_enabled": false,
"payment_cancel_rate_limit_max": 0,
"payment_cancel_rate_limit_window": 0,
"payment_cancel_rate_limit_unit": "",
"payment_cancel_rate_limit_window_mode": "",
- "custom_menu_items": [],
- "custom_endpoints": []
+ "balance_low_notify_enabled": false,
+ "account_quota_notify_enabled": false,
+ "balance_low_notify_threshold": 0,
+ "balance_low_notify_recharge_url": "",
+ "account_quota_notify_emails": []
}
}`,
},
@@ -699,7 +709,7 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode: config.RunModeStandard,
}
- userService := service.NewUserService(userRepo, nil, nil)
+ userService := service.NewUserService(userRepo, nil, nil, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go
index a8034e98..023e40bb 100644
--- a/backend/internal/server/http.go
+++ b/backend/internal/server/http.go
@@ -2,12 +2,15 @@
package server
import (
+ "context"
"log"
+ "log/slog"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -56,6 +59,42 @@ func ProvideRouter(
}
}
+ // Wire up websearch Manager builder so it initializes on startup and rebuilds on config save.
+ settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig, proxyURLs map[int64]string) {
+ if cfg == nil || !cfg.Enabled || len(cfg.Providers) == 0 {
+ service.SetWebSearchManager(nil)
+ return
+ }
+ configs := make([]websearch.ProviderConfig, 0, len(cfg.Providers))
+ for _, p := range cfg.Providers {
+ if p.APIKey == "" {
+ continue
+ }
+ pc := websearch.ProviderConfig{
+ Type: p.Type,
+ APIKey: p.APIKey,
+ QuotaLimit: derefInt64(p.QuotaLimit),
+ ExpiresAt: p.ExpiresAt,
+ }
+ if p.SubscribedAt != nil {
+ pc.SubscribedAt = p.SubscribedAt
+ }
+ if p.ProxyID != nil {
+ pc.ProxyID = *p.ProxyID
+ if u, ok := proxyURLs[*p.ProxyID]; ok {
+ pc.ProxyURL = u
+ } else {
+ // Proxy configured but not found — skip this provider to prevent direct connection.
+ slog.Warn("websearch: proxy not found for provider, skipping",
+ "provider", p.Type, "proxy_id", *p.ProxyID)
+ continue
+ }
+ }
+ configs = append(configs, pc)
+ }
+ service.SetWebSearchManager(websearch.NewManager(configs, redisClient))
+ })
+
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
}
@@ -102,3 +141,10 @@ func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
// 不设置 ReadTimeout,因为大请求体可能需要较长时间读取
}
}
+
+func derefInt64(p *int64) int64 {
+ if p == nil {
+ return 0
+ }
+ return *p
+}
diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go
index aafe4a58..ed2578c8 100644
--- a/backend/internal/server/middleware/admin_auth_test.go
+++ b/backend/internal/server/middleware/admin_auth_test.go
@@ -39,7 +39,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
return &clone, nil
},
}
- userService := service.NewUserService(userRepo, nil, nil)
+ userService := service.NewUserService(userRepo, nil, nil, nil)
router := gin.New()
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go
index ad9c1b5b..c483a51e 100644
--- a/backend/internal/server/middleware/jwt_auth_test.go
+++ b/backend/internal/server/middleware/jwt_auth_test.go
@@ -41,7 +41,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
userRepo := &stubJWTUserRepo{users: users}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
- userSvc := service.NewUserService(userRepo, nil, nil)
+ userSvc := service.NewUserService(userRepo, nil, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)
r := gin.New()
diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go
index 73210bfc..7021ab2e 100644
--- a/backend/internal/server/middleware/security_headers.go
+++ b/backend/internal/server/middleware/security_headers.go
@@ -18,6 +18,8 @@ const (
NonceTemplate = "__CSP_NONCE__"
// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
+ // StripeDomain is the domain for Stripe.js SDK
+ StripeDomain = "https://*.stripe.com"
)
// GenerateNonce generates a cryptographically secure random nonce.
@@ -97,8 +99,9 @@ func isAPIRoutePath(c *gin.Context) bool {
strings.HasPrefix(path, "/responses")
}
-// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
-// This allows the application to work correctly even if the config file has an older CSP policy.
+// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights,
+// and Stripe.js domains. This allows the application to work correctly even if the
+// config file has an older CSP policy.
func enhanceCSPPolicy(policy string) string {
// Add nonce placeholder to script-src if not present
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
@@ -110,6 +113,12 @@ func enhanceCSPPolicy(policy string) string {
policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
}
+ // Add Stripe.js domain to script-src and frame-src if not present
+ if !strings.Contains(policy, "stripe.com") {
+ policy = addToDirective(policy, "script-src", StripeDomain)
+ policy = addToDirective(policy, "frame-src", StripeDomain)
+ }
+
return policy
}
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index b921da95..9af0fd8e 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -407,6 +407,11 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Beta 策略配置
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
+ // Web Search 模拟配置
+ adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig)
+ adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig)
+ adminSettings.POST("/web-search-emulation/test", h.Admin.Setting.TestWebSearchEmulation)
+ adminSettings.POST("/web-search-emulation/reset-usage", h.Admin.Setting.ResetWebSearchUsage)
}
}
diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go
index 6bf04679..23bd58ad 100644
--- a/backend/internal/server/routes/payment.go
+++ b/backend/internal/server/routes/payment.go
@@ -39,6 +39,7 @@ func RegisterPaymentRoutes(
orders.GET("/:id", paymentHandler.GetOrder)
orders.POST("/:id/cancel", paymentHandler.CancelOrder)
orders.POST("/:id/refund-request", paymentHandler.RequestRefund)
+ orders.GET("/refund-eligible-providers", paymentHandler.GetRefundEligibleProviders)
}
}
diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go
index c3b82742..d004f8b4 100644
--- a/backend/internal/server/routes/user.go
+++ b/backend/internal/server/routes/user.go
@@ -26,6 +26,15 @@ func RegisterUserRoutes(
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
+ // 通知邮箱管理
+ notifyEmail := user.Group("/notify-email")
+ {
+ notifyEmail.POST("/send-code", h.User.SendNotifyEmailCode)
+ notifyEmail.POST("/verify", h.User.VerifyNotifyEmail)
+ notifyEmail.PUT("/toggle", h.User.ToggleNotifyEmail)
+ notifyEmail.DELETE("", h.User.RemoveNotifyEmail)
+ }
+
// TOTP 双因素认证
totp := user.Group("/totp")
{
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 512195e3..52db3073 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"hash/fnv"
+ "log/slog"
"reflect"
"sort"
"strconv"
@@ -969,7 +970,7 @@ func (a *Account) IsOveragesEnabled() bool {
return false
}
-// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
+// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用"自动透传(仅替换认证)"。
//
// 新字段:accounts.extra.openai_passthrough。
// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。
@@ -1133,7 +1134,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return resolvedDefault
}
-// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
+// IsOpenAIWSForceHTTPEnabled 返回账号级"强制 HTTP"开关。
// 字段:accounts.extra.openai_ws_force_http。
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
@@ -1158,7 +1159,7 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
}
-// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。
+// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用"自动透传(仅替换认证)"。
// 字段:accounts.extra.anthropic_passthrough。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
@@ -1169,7 +1170,42 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
return ok && enabled
}
-// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。
+// WebSearch 模拟三态常量
+const (
+ WebSearchModeDefault = "default" // 跟随渠道配置
+ WebSearchModeEnabled = "enabled" // 强制开启
+ WebSearchModeDisabled = "disabled" // 强制关闭
+)
+
+// GetWebSearchEmulationMode 返回账号的 WebSearch 模拟模式。
+// 三态:default(跟随渠道)/ enabled(强制开启)/ disabled(强制关闭)。
+// 兼容旧 bool 值:true→enabled, false→default(并记录 debug 日志)。
+func (a *Account) GetWebSearchEmulationMode() string {
+ if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
+ return WebSearchModeDefault
+ }
+ raw := a.Extra[featureKeyWebSearchEmulation]
+ // Tolerant: legacy bool values (pre-migration or stale writes)
+ if b, ok := raw.(bool); ok {
+ slog.Debug("legacy bool web_search_emulation value", "account_id", a.ID, "value", b)
+ if b {
+ return WebSearchModeEnabled
+ }
+ return WebSearchModeDefault
+ }
+ mode, ok := raw.(string)
+ if !ok {
+ return WebSearchModeDefault
+ }
+ switch mode {
+ case WebSearchModeEnabled, WebSearchModeDisabled:
+ return mode
+ default:
+ return WebSearchModeDefault
+ }
+}
+
+// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
// 字段:accounts.extra.codex_cli_only。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func (a *Account) IsCodexCLIOnlyEnabled() bool {
@@ -1395,6 +1431,19 @@ func (a *Account) getExtraTime(key string) time.Time {
return time.Time{}
}
+// getExtraBool 从 Extra 中读取指定 key 的 bool 值
+func (a *Account) getExtraBool(key string) bool {
+ if a.Extra == nil {
+ return false
+ }
+ if v, ok := a.Extra[key]; ok {
+ if b, ok := v.(bool); ok {
+ return b
+ }
+ }
+ return false
+}
+
// getExtraString 从 Extra 中读取指定 key 的字符串值
func (a *Account) getExtraString(key string) string {
if a.Extra == nil {
@@ -1408,6 +1457,14 @@ func (a *Account) getExtraString(key string) string {
return ""
}
+// getExtraStringDefault 从 Extra 中读取指定 key 的字符串值,不存在时返回 defaultVal
+func (a *Account) getExtraStringDefault(key, defaultVal string) string {
+ if v := a.getExtraString(key); v != "" {
+ return v
+ }
+ return defaultVal
+}
+
// getExtraInt 从 Extra 中读取指定 key 的 int 值
func (a *Account) getExtraInt(key string) int {
if a.Extra == nil {
@@ -1464,6 +1521,62 @@ func (a *Account) GetQuotaResetTimezone() string {
return "UTC"
}
+// --- Quota Notification Getters ---
+
+// QuotaNotifyConfig returns the notify configuration for a given quota dimension.
+// dim must be one of quotaDimDaily, quotaDimWeekly, quotaDimTotal.
+func (a *Account) QuotaNotifyConfig(dim string) (enabled bool, threshold float64, thresholdType string) {
+ enabled = a.getExtraBool("quota_notify_" + dim + "_enabled")
+ threshold = a.getExtraFloat64("quota_notify_" + dim + "_threshold")
+ thresholdType = a.getExtraStringDefault("quota_notify_"+dim+"_threshold_type", thresholdTypeFixed)
+ return
+}
+
+func (a *Account) GetQuotaNotifyDailyEnabled() bool {
+ e, _, _ := a.QuotaNotifyConfig(quotaDimDaily)
+ return e
+}
+
+func (a *Account) GetQuotaNotifyDailyThreshold() float64 {
+ _, t, _ := a.QuotaNotifyConfig(quotaDimDaily)
+ return t
+}
+
+func (a *Account) GetQuotaNotifyDailyThresholdType() string {
+ _, _, tt := a.QuotaNotifyConfig(quotaDimDaily)
+ return tt
+}
+
+func (a *Account) GetQuotaNotifyWeeklyEnabled() bool {
+ e, _, _ := a.QuotaNotifyConfig(quotaDimWeekly)
+ return e
+}
+
+func (a *Account) GetQuotaNotifyWeeklyThreshold() float64 {
+ _, t, _ := a.QuotaNotifyConfig(quotaDimWeekly)
+ return t
+}
+
+func (a *Account) GetQuotaNotifyWeeklyThresholdType() string {
+ _, _, tt := a.QuotaNotifyConfig(quotaDimWeekly)
+ return tt
+}
+
+func (a *Account) GetQuotaNotifyTotalEnabled() bool {
+ e, _, _ := a.QuotaNotifyConfig(quotaDimTotal)
+ return e
+}
+
+func (a *Account) GetQuotaNotifyTotalThreshold() float64 {
+ _, t, _ := a.QuotaNotifyConfig(quotaDimTotal)
+ return t
+}
+
+func (a *Account) GetQuotaNotifyTotalThresholdType() string {
+ _, _, tt := a.QuotaNotifyConfig(quotaDimTotal)
+ return tt
+}
+
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
t := after.In(tz)
diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go
new file mode 100644
index 00000000..90ff450f
--- /dev/null
+++ b/backend/internal/service/account_stats_pricing.go
@@ -0,0 +1,236 @@
+package service
+
+import (
+ "context"
+ "strings"
+)
+
+// resolveAccountStatsCost 计算账号统计定价费用。
+// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
+//
+// 优先级(先命中为准):
+// 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关)
+// 2. ApplyPricingToAccountStats 启用时,直接使用本次请求的客户计费(倍率前的 totalCost)
+// 3. 模型定价文件(LiteLLM)中上游模型的默认价格
+// 4. nil → 走默认公式(total_cost × account_rate_multiplier)
+//
+// upstreamModel 是最终发往上游的模型 ID。
+// totalCost 是本次请求的客户计费(倍率前),用于优先级 2。
+func resolveAccountStatsCost(
+ ctx context.Context,
+ channelService *ChannelService,
+ billingService *BillingService,
+ accountID int64,
+ groupID int64,
+ upstreamModel string,
+ tokens UsageTokens,
+ requestCount int,
+ totalCost float64,
+) *float64 {
+ if channelService == nil || upstreamModel == "" {
+ return nil
+ }
+ channel, err := channelService.GetChannelForGroup(ctx, groupID)
+ if err != nil || channel == nil {
+ return nil
+ }
+
+ platform := channelService.GetGroupPlatform(ctx, groupID)
+
+ // 优先级 1:自定义规则(始终尝试)
+ if cost := tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount); cost != nil {
+ return cost
+ }
+
+ // 优先级 2:渠道开启"应用模型定价到账号统计"时,直接使用客户计费(倍率前)
+ if channel.ApplyPricingToAccountStats {
+ cost := totalCost
+ if cost <= 0 {
+ return nil
+ }
+ return &cost
+ }
+
+ // 优先级 3:模型定价文件(LiteLLM)默认价格
+ if billingService != nil {
+ return tryModelFilePricing(billingService, upstreamModel, tokens)
+ }
+
+ return nil
+}
+
+// tryModelFilePricing 使用模型定价文件(LiteLLM/fallback)中的标准价格计算费用。
+func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 {
+ pricing, err := billingService.GetModelPricing(model)
+ if err != nil || pricing == nil {
+ return nil
+ }
+ cost := float64(tokens.InputTokens)*pricing.InputPricePerToken +
+ float64(tokens.OutputTokens)*pricing.OutputPricePerToken +
+ float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken +
+ float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken +
+ float64(tokens.ImageOutputTokens)*pricing.ImageOutputPricePerToken
+ if cost <= 0 {
+ return nil
+ }
+ return &cost
+}
+
+// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
+func tryCustomRules(
+ channel *Channel, accountID, groupID int64,
+ platform, model string, tokens UsageTokens, requestCount int,
+) *float64 {
+ modelLower := strings.ToLower(model)
+ for _, rule := range channel.AccountStatsPricingRules {
+ if !matchAccountStatsRule(&rule, accountID, groupID) {
+ continue
+ }
+ pricing := findPricingForModel(rule.Pricing, platform, modelLower)
+ if pricing == nil {
+ continue // 规则匹配但模型不在规则定价中,继续下一条
+ }
+ return calculateStatsCost(pricing, tokens, requestCount)
+ }
+ return nil
+}
+
+// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
+// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
+// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
+func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
+ if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 {
+ return false
+ }
+ for _, id := range rule.AccountIDs {
+ if id == accountID {
+ return true
+ }
+ }
+ for _, id := range rule.GroupIDs {
+ if id == groupID {
+ return true
+ }
+ }
+ return false
+}
+
+// findPricingForModel 在定价列表中查找匹配的模型定价。
+// 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。
+func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
+ // 精确匹配优先
+ for i := range pricingList {
+ p := &pricingList[i]
+ if !isPlatformMatch(platform, p.Platform) {
+ continue
+ }
+ for _, m := range p.Models {
+ if strings.ToLower(m) == modelLower {
+ return p
+ }
+ }
+ }
+ // 通配符匹配:按配置顺序,先匹配先使用
+ for i := range pricingList {
+ p := &pricingList[i]
+ if !isPlatformMatch(platform, p.Platform) {
+ continue
+ }
+ for _, m := range p.Models {
+ ml := strings.ToLower(m)
+ if !strings.HasSuffix(ml, "*") {
+ continue
+ }
+ prefix := strings.TrimSuffix(ml, "*")
+ if strings.HasPrefix(modelLower, prefix) {
+ return p
+ }
+ }
+ }
+ return nil
+}
+
+// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。
+func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
+ if queryPlatform == "" || pricingPlatform == "" {
+ return true
+ }
+ return queryPlatform == pricingPlatform
+}
+
+// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。
+func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
+ if pricing == nil {
+ return nil
+ }
+ switch pricing.BillingMode {
+ case BillingModePerRequest, BillingModeImage:
+ return calculatePerRequestStatsCost(pricing, requestCount)
+ default:
+ return calculateTokenStatsCost(pricing, tokens)
+ }
+}
+
+// calculatePerRequestStatsCost 按次/图片计费。
+func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
+ if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 {
+ return nil
+ }
+ cost := *pricing.PerRequestPrice * float64(requestCount)
+ return &cost
+}
+
+// calculateTokenStatsCost Token 计费。
+// If the pricing has intervals, find the matching interval by total token count
+// and use its prices instead of the flat pricing fields.
+func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
+ p := pricing
+ if len(pricing.Intervals) > 0 {
+ totalTokens := tokens.InputTokens + tokens.OutputTokens + tokens.CacheCreationTokens + tokens.CacheReadTokens
+ if iv := FindMatchingInterval(pricing.Intervals, totalTokens); iv != nil {
+ p = &ChannelModelPricing{
+ InputPrice: iv.InputPrice,
+ OutputPrice: iv.OutputPrice,
+ CacheWritePrice: iv.CacheWritePrice,
+ CacheReadPrice: iv.CacheReadPrice,
+ PerRequestPrice: iv.PerRequestPrice,
+ }
+ }
+ }
+ deref := func(ptr *float64) float64 {
+ if ptr == nil {
+ return 0
+ }
+ return *ptr
+ }
+ cost := float64(tokens.InputTokens)*deref(p.InputPrice) +
+ float64(tokens.OutputTokens)*deref(p.OutputPrice) +
+ float64(tokens.CacheCreationTokens)*deref(p.CacheWritePrice) +
+ float64(tokens.CacheReadTokens)*deref(p.CacheReadPrice) +
+ float64(tokens.ImageOutputTokens)*deref(p.ImageOutputPrice)
+ if cost <= 0 {
+ return nil
+ }
+ return &cost
+}
+
+// applyAccountStatsCost resolves the account stats cost for a usage log entry.
+// It resolves the upstream model (falling back to the requested model) and calls
+// the 4-level priority chain via resolveAccountStatsCost.
+func applyAccountStatsCost(
+ ctx context.Context,
+ usageLog *UsageLog,
+ cs *ChannelService, bs *BillingService,
+ accountID int64, groupID int64,
+ upstreamModel, requestedModel string,
+ tokens UsageTokens,
+ totalCost float64,
+) {
+ model := upstreamModel
+ if model == "" {
+ model = requestedModel
+ }
+ usageLog.AccountStatsCost = resolveAccountStatsCost(
+ ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost,
+ )
+}
diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go
new file mode 100644
index 00000000..36e5eb74
--- /dev/null
+++ b/backend/internal/service/account_stats_pricing_test.go
@@ -0,0 +1,771 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// matchAccountStatsRule
+// ---------------------------------------------------------------------------
+
+func TestMatchAccountStatsRule_BothEmpty_NoMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{}
+ require.False(t, matchAccountStatsRule(rule, 1, 10))
+}
+
+func TestMatchAccountStatsRule_AccountIDMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{AccountIDs: []int64{1, 2, 3}}
+ require.True(t, matchAccountStatsRule(rule, 2, 999))
+}
+
+func TestMatchAccountStatsRule_GroupIDMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{GroupIDs: []int64{10, 20}}
+ require.True(t, matchAccountStatsRule(rule, 999, 20))
+}
+
+func TestMatchAccountStatsRule_BothConfigured_AccountMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{
+ AccountIDs: []int64{1, 2},
+ GroupIDs: []int64{10, 20},
+ }
+ require.True(t, matchAccountStatsRule(rule, 2, 999))
+}
+
+func TestMatchAccountStatsRule_BothConfigured_GroupMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{
+ AccountIDs: []int64{1, 2},
+ GroupIDs: []int64{10, 20},
+ }
+ require.True(t, matchAccountStatsRule(rule, 999, 10))
+}
+
+func TestMatchAccountStatsRule_BothConfigured_NeitherMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{
+ AccountIDs: []int64{1, 2},
+ GroupIDs: []int64{10, 20},
+ }
+ require.False(t, matchAccountStatsRule(rule, 999, 999))
+}
+
+// ---------------------------------------------------------------------------
+// findPricingForModel
+// ---------------------------------------------------------------------------
+
+func TestFindPricingForModel(t *testing.T) {
+ exactPricing := ChannelModelPricing{
+ ID: 1,
+ Models: []string{"claude-opus-4"},
+ }
+ wildcardPricing := ChannelModelPricing{
+ ID: 2,
+ Models: []string{"claude-*"},
+ }
+ platformPricing := ChannelModelPricing{
+ ID: 3,
+ Platform: "openai",
+ Models: []string{"gpt-4o"},
+ }
+ emptyPlatformPricing := ChannelModelPricing{
+ ID: 4,
+ Models: []string{"gemini-2.5-pro"},
+ }
+
+ tests := []struct {
+ name string
+ list []ChannelModelPricing
+ platform string
+ model string
+ wantID int64
+ wantNil bool
+ }{
+ {
+ name: "exact match",
+ list: []ChannelModelPricing{exactPricing},
+ platform: "anthropic",
+ model: "claude-opus-4",
+ wantID: 1,
+ },
+ {
+ name: "exact match case insensitive",
+ list: []ChannelModelPricing{{ID: 5, Models: []string{"Claude-Opus-4"}}},
+ platform: "",
+ model: "claude-opus-4",
+ wantID: 5,
+ },
+ {
+ name: "wildcard match",
+ list: []ChannelModelPricing{wildcardPricing},
+ platform: "anthropic",
+ model: "claude-opus-4",
+ wantID: 2,
+ },
+ {
+ name: "exact match takes priority over wildcard",
+ list: []ChannelModelPricing{wildcardPricing, exactPricing},
+ platform: "anthropic",
+ model: "claude-opus-4",
+ wantID: 1,
+ },
+ {
+ name: "platform mismatch skipped",
+ list: []ChannelModelPricing{platformPricing},
+ platform: "anthropic",
+ model: "gpt-4o",
+ wantNil: true,
+ },
+ {
+ name: "empty platform in pricing matches any",
+ list: []ChannelModelPricing{emptyPlatformPricing},
+ platform: "gemini",
+ model: "gemini-2.5-pro",
+ wantID: 4,
+ },
+ {
+ name: "empty platform in query matches any pricing platform",
+ list: []ChannelModelPricing{platformPricing},
+ platform: "",
+ model: "gpt-4o",
+ wantID: 3,
+ },
+ {
+ name: "no match at all",
+ list: []ChannelModelPricing{exactPricing, wildcardPricing},
+ platform: "anthropic",
+ model: "gpt-4o",
+ wantNil: true,
+ },
+ {
+ name: "empty list returns nil",
+ list: nil,
+ model: "claude-opus-4",
+ wantNil: true,
+ },
+ {
+ name: "wildcard matches by config order (first match wins)",
+ list: []ChannelModelPricing{
+ {ID: 10, Models: []string{"claude-*"}},
+ {ID: 11, Models: []string{"claude-opus-*"}},
+ },
+ platform: "",
+ model: "claude-opus-4",
+ wantID: 10, // config order: "claude-*" is first and matches, so it wins
+ },
+ {
+ name: "shorter wildcard used when longer does not match",
+ list: []ChannelModelPricing{
+ {ID: 10, Models: []string{"claude-*"}},
+ {ID: 11, Models: []string{"claude-opus-*"}},
+ },
+ platform: "",
+ model: "claude-sonnet-4",
+ wantID: 10, // only "claude-*" matches
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := findPricingForModel(tt.list, tt.platform, tt.model)
+ if tt.wantNil {
+ require.Nil(t, result)
+ return
+ }
+ require.NotNil(t, result)
+ require.Equal(t, tt.wantID, result.ID)
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// calculateStatsCost
+// ---------------------------------------------------------------------------
+
+func TestCalculateStatsCost_NilPricing(t *testing.T) {
+ result := calculateStatsCost(nil, UsageTokens{}, 1)
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_TokenBilling(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
+ require.InDelta(t, 0.2, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_TokenBilling_WithCache(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ CacheWritePrice: testPtrFloat64(0.003),
+ CacheReadPrice: testPtrFloat64(0.0005),
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ CacheCreationTokens: 200,
+ CacheReadTokens: 300,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
+ // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
+ require.InDelta(t, 0.95, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_TokenBilling_WithImageOutput(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ ImageOutputPrice: testPtrFloat64(0.01),
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ ImageOutputTokens: 10,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
+ require.InDelta(t, 0.3, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_TokenBilling_PartialPricesNil(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ // OutputPrice, CacheWritePrice, etc. are all nil → treated as 0
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ CacheCreationTokens: 200,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ // Only input contributes: 100*0.001 = 0.1
+ require.InDelta(t, 0.1, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_TokenBilling_AllTokensZero(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ }
+ tokens := UsageTokens{} // all zeros
+ result := calculateStatsCost(pricing, tokens, 1)
+ // totalCost == 0 → returns nil (does not override, falls back to default formula)
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_PerRequestBilling(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModePerRequest,
+ PerRequestPrice: testPtrFloat64(0.05),
+ }
+ tokens := UsageTokens{InputTokens: 999, OutputTokens: 999}
+ result := calculateStatsCost(pricing, tokens, 3)
+ require.NotNil(t, result)
+ // 0.05 * 3 = 0.15
+ require.InDelta(t, 0.15, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_PerRequestBilling_PriceNil(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModePerRequest,
+ // PerRequestPrice is nil
+ }
+ result := calculateStatsCost(pricing, UsageTokens{}, 1)
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_PerRequestBilling_PriceZero(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModePerRequest,
+ PerRequestPrice: testPtrFloat64(0),
+ }
+ result := calculateStatsCost(pricing, UsageTokens{}, 1)
+ // price == 0 → condition *pricing.PerRequestPrice > 0 is false → returns nil
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_ImageBilling(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeImage,
+ PerRequestPrice: testPtrFloat64(0.10),
+ }
+ result := calculateStatsCost(pricing, UsageTokens{}, 2)
+ require.NotNil(t, result)
+ // 0.10 * 2 = 0.20
+ require.InDelta(t, 0.20, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_ImageBilling_PriceNil(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeImage,
+ // PerRequestPrice is nil
+ }
+ result := calculateStatsCost(pricing, UsageTokens{}, 1)
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_DefaultBillingMode_FallsToToken(t *testing.T) {
+ // BillingMode is empty string (default) → falls into token billing
+ pricing := &ChannelModelPricing{
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ require.InDelta(t, 0.2, *result, 1e-12)
+}
+
+// ---------------------------------------------------------------------------
+// tryCustomRules — 多规则顺序测试
+// ---------------------------------------------------------------------------
+
+func TestTryCustomRules_FirstMatchWins(t *testing.T) {
+ channel := &Channel{
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ GroupIDs: []int64{1},
+ Pricing: []ChannelModelPricing{
+ {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01), OutputPrice: testPtrFloat64(0.02)},
+ },
+ },
+ {
+ GroupIDs: []int64{1},
+ Pricing: []ChannelModelPricing{
+ {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99), OutputPrice: testPtrFloat64(0.99)},
+ },
+ },
+ },
+ }
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+ result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
+ require.NotNil(t, result)
+ // 应使用第一条规则的价格:100*0.01 + 50*0.02 = 2.0
+ require.InDelta(t, 2.0, *result, 1e-12)
+}
+
+func TestTryCustomRules_SkipsNonMatchingRules(t *testing.T) {
+ channel := &Channel{
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ AccountIDs: []int64{888}, // 不匹配
+ Pricing: []ChannelModelPricing{
+ {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99)},
+ },
+ },
+ {
+ GroupIDs: []int64{1}, // 匹配
+ Pricing: []ChannelModelPricing{
+ {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)},
+ },
+ },
+ },
+ }
+ tokens := UsageTokens{InputTokens: 100}
+ result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
+ require.NotNil(t, result)
+ // 跳过规则1(账号不匹配),使用规则2:100*0.05 = 5.0
+ require.InDelta(t, 5.0, *result, 1e-12)
+}
+
+func TestTryCustomRules_NoMatch_ReturnsNil(t *testing.T) {
+ channel := &Channel{
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ AccountIDs: []int64{888},
+ Pricing: []ChannelModelPricing{
+ {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01)},
+ },
+ },
+ },
+ }
+ tokens := UsageTokens{InputTokens: 100}
+ result := tryCustomRules(channel, 999, 2, "", "claude-opus-4", tokens, 1)
+ require.Nil(t, result) // 账号和分组都不匹配
+}
+
+func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) {
+ channel := &Channel{
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ GroupIDs: []int64{1},
+ Pricing: []ChannelModelPricing{
+ {ID: 100, Models: []string{"gpt-4o"}, InputPrice: testPtrFloat64(0.01)}, // 模型不匹配
+ },
+ },
+ {
+ GroupIDs: []int64{1},
+ Pricing: []ChannelModelPricing{
+ {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, // 模型匹配
+ },
+ },
+ },
+ }
+ tokens := UsageTokens{InputTokens: 100}
+ result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
+ require.NotNil(t, result)
+ require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2
+}
+
+// ---------------------------------------------------------------------------
+// tryModelFilePricing
+// ---------------------------------------------------------------------------
+
+// newTestBillingServiceWithPrices creates a BillingService with pre-populated
+// fallback prices for testing. No config or pricing service is needed.
+// The key must match what getFallbackPricing resolves to for a given model name.
+// E.g., model "claude-sonnet-4" resolves to key "claude-sonnet-4".
+func newTestBillingServiceWithPrices(prices map[string]*ModelPricing) *BillingService {
+ return &BillingService{
+ fallbackPrices: prices,
+ }
+}
+
+func TestTryModelFilePricing_Success(t *testing.T) {
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ },
+ })
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
+ require.InDelta(t, 0.2, *result, 1e-12)
+}
+
+func TestTryModelFilePricing_PricingNotFound(t *testing.T) {
+ // "nonexistent-model" does not match any fallback pattern
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{})
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+ result := tryModelFilePricing(bs, "nonexistent-model", tokens)
+ require.Nil(t, result)
+}
+
+func TestTryModelFilePricing_NilFallback(t *testing.T) {
+ // getFallbackPricing returns nil when key maps to nil
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": nil,
+ })
+ tokens := UsageTokens{InputTokens: 100}
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.Nil(t, result)
+}
+
+func TestTryModelFilePricing_ZeroCost(t *testing.T) {
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ },
+ })
+ tokens := UsageTokens{} // all zero tokens → cost = 0 → nil
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.Nil(t, result)
+}
+
+func TestTryModelFilePricing_WithImageOutput(t *testing.T) {
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ ImageOutputPricePerToken: 0.01,
+ },
+ })
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ ImageOutputTokens: 10,
+ }
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
+ require.InDelta(t, 0.3, *result, 1e-12)
+}
+
+func TestTryModelFilePricing_WithCacheTokens(t *testing.T) {
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ CacheCreationPricePerToken: 0.003,
+ CacheReadPricePerToken: 0.0005,
+ },
+ })
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ CacheCreationTokens: 200,
+ CacheReadTokens: 300,
+ }
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
+ // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
+ require.InDelta(t, 0.95, *result, 1e-12)
+}
+
+// ---------------------------------------------------------------------------
+// resolveAccountStatsCost — integration tests covering the 4-level priority chain
+// ---------------------------------------------------------------------------
+
+func TestResolveAccountStatsCost_NilChannelService(t *testing.T) {
+ result := resolveAccountStatsCost(
+ context.Background(),
+ nil, // channelService is nil
+ newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
+ 1, 1, "claude-sonnet-4",
+ UsageTokens{InputTokens: 100}, 1, 0.5,
+ )
+ require.Nil(t, result)
+}
+
+func TestResolveAccountStatsCost_EmptyUpstreamModel(t *testing.T) {
+ cs := newTestChannelServiceForStats(t, &Channel{
+ ID: 1,
+ Status: StatusActive,
+ }, 1, "")
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs,
+ newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
+ 1, 1, "", // empty upstream model
+ UsageTokens{InputTokens: 100}, 1, 0.5,
+ )
+ require.Nil(t, result)
+}
+
+func TestResolveAccountStatsCost_GetChannelForGroupReturnsNil(t *testing.T) {
+ // Group 99 is NOT in the cache, so GetChannelForGroup returns nil
+ cs := newTestChannelServiceForStats(t, &Channel{
+ ID: 1,
+ Status: StatusActive,
+ }, 1, "")
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs,
+ newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
+ 1, 99, "claude-sonnet-4", // groupID 99 has no channel
+ UsageTokens{InputTokens: 100}, 1, 0.5,
+ )
+ require.Nil(t, result)
+}
+
+func TestResolveAccountStatsCost_HitsCustomRule(t *testing.T) {
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ GroupIDs: []int64{10},
+ Pricing: []ChannelModelPricing{
+ {
+ ID: 100,
+ Models: []string{"claude-sonnet-4"},
+ InputPrice: testPtrFloat64(0.01),
+ OutputPrice: testPtrFloat64(0.02),
+ },
+ },
+ },
+ },
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, nil, // billingService not needed when custom rule hits
+ 1, 10, "claude-sonnet-4",
+ tokens, 1, 999.0, // totalCost ignored because custom rule hits
+ )
+ require.NotNil(t, result)
+ // 100*0.01 + 50*0.02 = 1.0 + 1.0 = 2.0
+ require.InDelta(t, 2.0, *result, 1e-12)
+}
+
+func TestResolveAccountStatsCost_ApplyPricingToAccountStats_UsesTotalCost(t *testing.T) {
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ ApplyPricingToAccountStats: true,
+ // No custom rules
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, nil,
+ 1, 10, "claude-sonnet-4",
+ tokens, 1, 0.75, // totalCost = 0.75
+ )
+ require.NotNil(t, result)
+ require.InDelta(t, 0.75, *result, 1e-12)
+}
+
+func TestResolveAccountStatsCost_ApplyPricingToAccountStats_ZeroTotalCost_ReturnsNil(t *testing.T) {
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ ApplyPricingToAccountStats: true,
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, nil,
+ 1, 10, "claude-sonnet-4",
+ UsageTokens{}, 1, 0.0, // totalCost = 0
+ )
+ require.Nil(t, result)
+}
+
+func TestResolveAccountStatsCost_FallsBackToLiteLLM(t *testing.T) {
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ ApplyPricingToAccountStats: false, // not enabled
+ // No custom rules
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ },
+ })
+
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, bs,
+ 1, 10, "claude-sonnet-4",
+ tokens, 1, 999.0, // totalCost ignored
+ )
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
+ require.InDelta(t, 0.2, *result, 1e-12)
+}
+
+func TestResolveAccountStatsCost_AllMiss_ReturnsNil(t *testing.T) {
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ ApplyPricingToAccountStats: false,
+ // No custom rules
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ // BillingService with no pricing for the model
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{})
+
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, bs,
+ 1, 10, "totally-unknown-model",
+ tokens, 1, 0.0,
+ )
+ require.Nil(t, result)
+}
+
+func TestResolveAccountStatsCost_NilBillingService_SkipsLiteLLM(t *testing.T) {
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ ApplyPricingToAccountStats: false,
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, nil, // billingService is nil
+ 1, 10, "claude-sonnet-4",
+ UsageTokens{InputTokens: 100}, 1, 0.0,
+ )
+ require.Nil(t, result)
+}
+
+func TestResolveAccountStatsCost_CustomRulePriorityOverApplyPricing(t *testing.T) {
+ // Both custom rule and ApplyPricingToAccountStats are configured;
+ // custom rule should take precedence.
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ ApplyPricingToAccountStats: true,
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ GroupIDs: []int64{10},
+ Pricing: []ChannelModelPricing{
+ {
+ ID: 100,
+ Models: []string{"claude-sonnet-4"},
+ InputPrice: testPtrFloat64(0.05),
+ },
+ },
+ },
+ },
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ tokens := UsageTokens{InputTokens: 100}
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, nil,
+ 1, 10, "claude-sonnet-4",
+ tokens, 1, 99.0, // totalCost = 99.0 (would be used if ApplyPricing wins)
+ )
+ require.NotNil(t, result)
+ // Custom rule: 100*0.05 = 5.0 (NOT 99.0 from totalCost)
+ require.InDelta(t, 5.0, *result, 1e-12)
+}
+
+// ---------------------------------------------------------------------------
+// helpers for resolveAccountStatsCost tests
+// ---------------------------------------------------------------------------
+
+// newTestChannelServiceForStats creates a ChannelService with a single channel
+// mapped to the given groupID, suitable for resolveAccountStatsCost tests.
+func newTestChannelServiceForStats(t *testing.T, channel *Channel, groupID int64, platform string) *ChannelService {
+ t.Helper()
+ cache := newEmptyChannelCache()
+ cache.channelByGroupID[groupID] = channel
+ cache.groupPlatform[groupID] = platform
+ cs := &ChannelService{}
+ cache.loadedAt = time.Now()
+ cs.cache.Store(cache)
+ return cs
+}
diff --git a/backend/internal/service/account_websearch_test.go b/backend/internal/service/account_websearch_test.go
new file mode 100644
index 00000000..6ed69d4c
--- /dev/null
+++ b/backend/internal/service/account_websearch_test.go
@@ -0,0 +1,105 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestGetWebSearchEmulationMode_Enabled(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
+ }
+ require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_Disabled(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "disabled"},
+ }
+ require.Equal(t, WebSearchModeDisabled, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_Default(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "default"},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_UnknownString(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "unknown"},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_OldBoolTrue(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: true},
+ }
+ // bool true → tolerant fallback → enabled (not default)
+ require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_OldBoolFalse(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: false},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_NilAccount(t *testing.T) {
+ var a *Account
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_NilExtra(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: nil,
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_MissingField(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_NonAnthropicPlatform(t *testing.T) {
+ a := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_NonAPIKeyType(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
index f9fd6742..419ddbc3 100644
--- a/backend/internal/service/admin_service_apikey_test.go
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -65,14 +65,14 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
panic("unexpected")
}
-func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
- panic("unexpected")
-}
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ panic("unexpected")
+}
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
type apiKeyRepoStubForGroupUpdate struct {
@@ -131,9 +131,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str
func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
-func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) {
- panic("unexpected")
-}
func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
@@ -158,6 +155,9 @@ func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, in
func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
panic("unexpected")
}
+func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) {
+ panic("unexpected")
+}
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
type groupRepoStubForGroupUpdate struct {
diff --git a/backend/internal/service/admin_service_clear_error_test.go b/backend/internal/service/admin_service_clear_error_test.go
index f039612c..141466dc 100644
--- a/backend/internal/service/admin_service_clear_error_test.go
+++ b/backend/internal/service/admin_service_clear_error_test.go
@@ -12,12 +12,12 @@ import (
type accountRepoStubForClearAccountError struct {
mockAccountRepoForGemini
- account *Account
- clearErrorCalls int
- clearRateLimitCalls int
- clearAntigravityCalls int
+ account *Account
+ clearErrorCalls int
+ clearRateLimitCalls int
+ clearAntigravityCalls int
clearModelRateLimitCalls int
- clearTempUnschedCalls int
+ clearTempUnschedCalls int
}
func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) {
@@ -60,13 +60,13 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
resetAt := time.Now().Add(5 * time.Minute)
repo := &accountRepoStubForClearAccountError{
account: &Account{
- ID: 31,
- Platform: PlatformOpenAI,
- Type: AccountTypeOAuth,
- Status: StatusError,
- ErrorMessage: "refresh failed",
- RateLimitResetAt: &resetAt,
- TempUnschedulableUntil: &until,
+ ID: 31,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusError,
+ ErrorMessage: "refresh failed",
+ RateLimitResetAt: &resetAt,
+ TempUnschedulableUntil: &until,
TempUnschedulableReason: "missing refresh token",
},
}
diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go
index c2e96df1..b1660ea7 100644
--- a/backend/internal/service/api_key_auth_cache.go
+++ b/backend/internal/service/api_key_auth_cache.go
@@ -34,6 +34,15 @@ type APIKeyAuthUserSnapshot struct {
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
+
+ // Balance notification fields (required for CheckBalanceAfterDeduction)
+ Email string `json:"email"`
+ Username string `json:"username"`
+ BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
+ BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
+ BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
+ TotalRecharged float64 `json:"total_recharged"`
}
// APIKeyAuthGroupSnapshot 分组快照
diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go
index 8069ed4f..2bd9a091 100644
--- a/backend/internal/service/api_key_auth_cache_impl.go
+++ b/backend/internal/service/api_key_auth_cache_impl.go
@@ -6,6 +6,7 @@ import (
"encoding/hex"
"errors"
"fmt"
+ "log/slog"
"math/rand/v2"
"time"
@@ -13,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
-const apiKeyAuthSnapshotVersion = 3
+const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold
type apiKeyAuthCacheConfig struct {
l1Size int
@@ -99,7 +100,7 @@ func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context
s.authCacheL1.Del(cacheKey)
}); err != nil {
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
- println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
+ slog.Warn("failed to start auth cache invalidation subscriber", "error", err)
}
}
@@ -219,11 +220,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
RateLimit1d: apiKey.RateLimit1d,
RateLimit7d: apiKey.RateLimit7d,
User: APIKeyAuthUserSnapshot{
- ID: apiKey.User.ID,
- Status: apiKey.User.Status,
- Role: apiKey.User.Role,
- Balance: apiKey.User.Balance,
- Concurrency: apiKey.User.Concurrency,
+ ID: apiKey.User.ID,
+ Status: apiKey.User.Status,
+ Role: apiKey.User.Role,
+ Balance: apiKey.User.Balance,
+ Concurrency: apiKey.User.Concurrency,
+ Email: apiKey.User.Email,
+ Username: apiKey.User.Username,
+ BalanceNotifyEnabled: apiKey.User.BalanceNotifyEnabled,
+ BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType,
+ BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
+ BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
+ TotalRecharged: apiKey.User.TotalRecharged,
},
}
if apiKey.Group != nil {
@@ -274,11 +282,18 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
RateLimit1d: snapshot.RateLimit1d,
RateLimit7d: snapshot.RateLimit7d,
User: &User{
- ID: snapshot.User.ID,
- Status: snapshot.User.Status,
- Role: snapshot.User.Role,
- Balance: snapshot.User.Balance,
- Concurrency: snapshot.User.Concurrency,
+ ID: snapshot.User.ID,
+ Status: snapshot.User.Status,
+ Role: snapshot.User.Role,
+ Balance: snapshot.User.Balance,
+ Concurrency: snapshot.User.Concurrency,
+ Email: snapshot.User.Email,
+ Username: snapshot.User.Username,
+ BalanceNotifyEnabled: snapshot.User.BalanceNotifyEnabled,
+ BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType,
+ BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
+ BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
+ TotalRecharged: snapshot.User.TotalRecharged,
},
}
if snapshot.Group != nil {
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index 7b50e90d..103bafe7 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -87,6 +87,18 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return nil
}
+func (s *emailCacheStub) GetNotifyVerifyCode(ctx context.Context, email string) (*VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *emailCacheStub) SetNotifyVerifyCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error {
+ return nil
+}
+
+func (s *emailCacheStub) DeleteNotifyVerifyCode(ctx context.Context, email string) error {
+ return nil
+}
+
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
return nil, nil
}
@@ -107,6 +119,14 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai
return nil
}
+func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
+ return 0, nil
+}
+
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
diff --git a/backend/internal/service/balance_notify_check_test.go b/backend/internal/service/balance_notify_check_test.go
new file mode 100644
index 00000000..7bb4cf9e
--- /dev/null
+++ b/backend/internal/service/balance_notify_check_test.go
@@ -0,0 +1,404 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// newBalanceNotifyServiceForTest constructs a BalanceNotifyService with an
+// in-memory settings repo and a non-nil emailService so that the guard-clause
+// nil-checks pass. The emailService is intentionally minimal — tests must
+// avoid crossing scenarios that would actually dispatch emails.
+func newBalanceNotifyServiceForTest() (*BalanceNotifyService, *mockSettingRepo) {
+ repo := newMockSettingRepo()
+ // EmailService is a concrete type; construct with the same repo so that
+ // any accidental fallback reads still succeed. Tests should not trigger a
+ // crossing that reaches SendEmail.
+ email := NewEmailService(repo, nil)
+ return NewBalanceNotifyService(email, repo, nil), repo
+}
+
+// ---------- guard clauses ----------
+
+func TestCheckBalanceAfterDeduction_NilUser(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ // Should not panic.
+ s.CheckBalanceAfterDeduction(context.Background(), nil, 100, 50)
+}
+
+func TestCheckBalanceAfterDeduction_UserNotifyDisabled(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
+ repo.data[SettingKeyBalanceLowNotifyThreshold] = "10"
+ u := &User{ID: 1, BalanceNotifyEnabled: false}
+ // Even with a crossing, disabled flag short-circuits.
+ s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
+}
+
+func TestCheckBalanceAfterDeduction_GlobalDisabled(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "false"
+ u := &User{ID: 1, BalanceNotifyEnabled: true}
+ s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
+}
+
+func TestCheckBalanceAfterDeduction_ThresholdZero(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
+ repo.data[SettingKeyBalanceLowNotifyThreshold] = "0"
+ u := &User{ID: 1, BalanceNotifyEnabled: true}
+ s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
+}
+
+func TestCheckBalanceAfterDeduction_UserThresholdOverride(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
+ repo.data[SettingKeyBalanceLowNotifyThreshold] = "100" // global default
+ customThreshold := 5.0
+ u := &User{
+ ID: 1,
+ BalanceNotifyEnabled: true,
+ BalanceNotifyThreshold: &customThreshold,
+ }
+ // User's 5.0 threshold takes precedence over global 100. 20 -> 15 does not
+ // cross 5, so nothing fires (verified by absence of panic).
+ s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
+}
+
+func TestCheckBalanceAfterDeduction_NoCrossingNotFired(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
+ repo.data[SettingKeyBalanceLowNotifyThreshold] = "10"
+ u := &User{ID: 1, BalanceNotifyEnabled: true}
+
+ // 100 -> 95, both remain above threshold=10, no crossing.
+ s.CheckBalanceAfterDeduction(context.Background(), u, 100, 5)
+ // 5 -> 3, both already below threshold, no crossing (only fires on first
+ // cross from above-to-below).
+ s.CheckBalanceAfterDeduction(context.Background(), u, 5, 2)
+}
+
+// ---------- nil-service guards on CheckAccountQuotaAfterIncrement ----------
+
+func TestCheckAccountQuotaAfterIncrement_NilAccount(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ // Should not panic.
+ s.CheckAccountQuotaAfterIncrement(context.Background(), nil, 10, nil)
+}
+
+func TestCheckAccountQuotaAfterIncrement_ZeroCost(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
+ s.CheckAccountQuotaAfterIncrement(context.Background(), a, 0, nil)
+}
+
+func TestCheckAccountQuotaAfterIncrement_NegativeCost(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
+ s.CheckAccountQuotaAfterIncrement(context.Background(), a, -5, nil)
+}
+
+func TestCheckAccountQuotaAfterIncrement_GlobalDisabled(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false"
+ a := &Account{
+ ID: 1,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_notify_daily_enabled": true,
+ "quota_notify_daily_threshold": 100.0,
+ "quota_daily_limit": 1000.0,
+ "quota_daily_used": 950.0,
+ },
+ }
+ // Global disabled → no processing even if a dim would cross.
+ s.CheckAccountQuotaAfterIncrement(context.Background(), a, 100, nil)
+}
+
+// ---------- sanity: internal helpers still work ----------
+
+func TestGetBalanceNotifyConfig_AllFields(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
+ repo.data[SettingKeyBalanceLowNotifyThreshold] = "12.5"
+ repo.data[SettingKeyBalanceLowNotifyRechargeURL] = "https://example.com/pay"
+
+ enabled, threshold, url := s.getBalanceNotifyConfig(context.Background())
+ require.True(t, enabled)
+ require.Equal(t, 12.5, threshold)
+ require.Equal(t, "https://example.com/pay", url)
+}
+
+func TestGetBalanceNotifyConfig_Disabled(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "false"
+
+ enabled, _, _ := s.getBalanceNotifyConfig(context.Background())
+ require.False(t, enabled)
+}
+
+func TestGetBalanceNotifyConfig_InvalidThreshold(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
+ repo.data[SettingKeyBalanceLowNotifyThreshold] = "not-a-number"
+
+ enabled, threshold, _ := s.getBalanceNotifyConfig(context.Background())
+ require.True(t, enabled)
+ require.Equal(t, 0.0, threshold)
+}
+
+func TestIsAccountQuotaNotifyEnabled(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+
+ // Missing key → false
+ require.False(t, s.isAccountQuotaNotifyEnabled(context.Background()))
+
+ // Explicit "false"
+ repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false"
+ require.False(t, s.isAccountQuotaNotifyEnabled(context.Background()))
+
+ // Explicit "true"
+ repo.data[SettingKeyAccountQuotaNotifyEnabled] = "true"
+ require.True(t, s.isAccountQuotaNotifyEnabled(context.Background()))
+}
+
+func TestGetSiteName_FallsBackToDefault(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ name := s.getSiteName(context.Background())
+ require.Equal(t, defaultSiteName, name)
+}
+
+func TestGetSiteName_Configured(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeySiteName] = "My Site"
+ require.Equal(t, "My Site", s.getSiteName(context.Background()))
+}
+
+// ---------- crossedDownward ----------
+
+func TestCrossedDownward_CrossesBelow(t *testing.T) {
+ // oldBalance > threshold, newBalance < threshold → true
+ require.True(t, crossedDownward(100, 5, 10))
+}
+
+func TestCrossedDownward_ExactlyAtThreshold(t *testing.T) {
+ // oldBalance > threshold, newBalance == threshold → false (not below)
+ require.False(t, crossedDownward(100, 10, 10))
+}
+
+func TestCrossedDownward_OldExactlyAtThreshold_NewBelow(t *testing.T) {
+ // oldBalance == threshold, newBalance < threshold → true
+ // (at-or-above → below counts as a crossing)
+ require.True(t, crossedDownward(10, 5, 10))
+}
+
+func TestCrossedDownward_AlreadyBelow(t *testing.T) {
+ // oldBalance < threshold → false (already below, no new crossing)
+ require.False(t, crossedDownward(5, 3, 10))
+}
+
+func TestCrossedDownward_BothAbove(t *testing.T) {
+ // oldBalance > threshold, newBalance > threshold → false (no crossing)
+ require.False(t, crossedDownward(100, 50, 10))
+}
+
+func TestCrossedDownward_ZeroThreshold(t *testing.T) {
+ // threshold == 0 → oldV >= 0 is always true, but newV < 0 only for negatives
+ // Typical case: positive balances should not fire when threshold is 0.
+ require.False(t, crossedDownward(10, 5, 0))
+ require.False(t, crossedDownward(0, 0, 0))
+}
+
+func TestCrossedDownward_ZeroThreshold_NegativeNew(t *testing.T) {
+ // Edge case: newBalance goes negative with threshold=0.
+ require.True(t, crossedDownward(5, -1, 0))
+}
+
+func TestCrossedDownward_NegativeValues(t *testing.T) {
+ // Both already negative, threshold is positive → no crossing (already below).
+ require.False(t, crossedDownward(-5, -10, 10))
+}
+
+func TestCrossedDownward_LargeDecrement(t *testing.T) {
+ // A single large deduction crosses the threshold.
+ require.True(t, crossedDownward(1000, 0.5, 100))
+}
+
+func TestCrossedDownward_SmallDecrement_NoCrossing(t *testing.T) {
+ // A tiny deduction stays above threshold.
+ require.False(t, crossedDownward(100, 99.99, 10))
+}
+
+// ---------- checkQuotaDimCrossings ----------
+
+func TestCheckQuotaDimCrossings_NoDimensions(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // Empty dims → no crossing, no panic.
+ s.checkQuotaDimCrossings(account, nil, 10, []string{"admin@example.com"}, "TestSite")
+ s.checkQuotaDimCrossings(account, []quotaDim{}, 10, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_DisabledDimension(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ dims := []quotaDim{
+ {
+ name: quotaDimDaily,
+ enabled: false, // disabled
+ threshold: 100,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 950,
+ limit: 1000,
+ },
+ }
+ // Disabled dimension should be skipped even if crossing would occur.
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_ZeroThresholdSkipped(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ dims := []quotaDim{
+ {
+ name: quotaDimDaily,
+ enabled: true,
+ threshold: 0, // zero threshold
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 950,
+ limit: 1000,
+ },
+ }
+ // Zero threshold → skipped.
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_NoCrossing_BothBelowThreshold(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
+ // currentUsed=300 (after), oldUsed=300-50=250 (before). Both < 600, no crossing.
+ dims := []quotaDim{
+ {
+ name: quotaDimDaily,
+ enabled: true,
+ threshold: 400,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 300,
+ limit: 1000,
+ },
+ }
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_NoCrossing_BothAboveThreshold(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
+ // currentUsed=800 (after), oldUsed=800-50=750 (before). Both >= 600, no crossing.
+ dims := []quotaDim{
+ {
+ name: quotaDimDaily,
+ enabled: true,
+ threshold: 400,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 800,
+ limit: 1000,
+ },
+ }
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_NegativeResolvedThreshold_Skipped(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // threshold=1200 remaining, limit=1000 → effectiveThreshold = 1000-1200 = -200
+ // Negative resolved threshold → skipped.
+ dims := []quotaDim{
+ {
+ name: quotaDimDaily,
+ enabled: true,
+ threshold: 1200,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 950,
+ limit: 1000,
+ },
+ }
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_PercentageThreshold_NoCrossing(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // threshold=30%, limit=1000 → effectiveThreshold = 1000 * (1 - 0.30) = 700
+ // currentUsed=500, oldUsed=500-50=450. Both < 700, no crossing.
+ dims := []quotaDim{
+ {
+ name: quotaDimWeekly,
+ enabled: true,
+ threshold: 30,
+ thresholdType: thresholdTypePercentage,
+ currentUsed: 500,
+ limit: 1000,
+ },
+ }
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_ZeroLimit_Skipped(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // limit=0 → resolvedThreshold returns 0 → skipped.
+ dims := []quotaDim{
+ {
+ name: quotaDimTotal,
+ enabled: true,
+ threshold: 100,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 50,
+ limit: 0,
+ },
+ }
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_MultipleDims_MixedResults(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // dim1: no crossing (both below effective threshold)
+ // dim2: disabled (skipped)
+ // dim3: zero threshold (skipped)
+ dims := []quotaDim{
+ {
+ name: quotaDimDaily,
+ enabled: true,
+ threshold: 400,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 300, // oldUsed=250, effectiveThreshold=600, both below
+ limit: 1000,
+ },
+ {
+ name: quotaDimWeekly,
+ enabled: false,
+ threshold: 100,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 900,
+ limit: 1000,
+ },
+ {
+ name: quotaDimTotal,
+ enabled: true,
+ threshold: 0,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 500,
+ limit: 1000,
+ },
+ }
+ // None should trigger. No panic expected.
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
diff --git a/backend/internal/service/balance_notify_email_body_test.go b/backend/internal/service/balance_notify_email_body_test.go
new file mode 100644
index 00000000..aee5a5bc
--- /dev/null
+++ b/backend/internal/service/balance_notify_email_body_test.go
@@ -0,0 +1,147 @@
+//go:build unit
+
+package service
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// These tests guard against fmt.Sprintf arg-count mismatches in the email
+// templates. A mismatch would produce "%!(EXTRA ...)" or "%!v(MISSING)" in
+// the output, which these assertions will catch.
+
+// ---------- buildBalanceLowEmailBody ----------
+
+func TestBuildBalanceLowEmailBody_ContainsRequiredFields(t *testing.T) {
+ s := &BalanceNotifyService{}
+ body := s.buildBalanceLowEmailBody("Alice", 3.14, 10.0, "MySite", "")
+
+ // All substituted values should appear in the output.
+ require.Contains(t, body, "MySite")
+ require.Contains(t, body, "Alice")
+ require.Contains(t, body, "$3.14")
+ require.Contains(t, body, "$10.00")
+
+ // No fmt.Sprintf format error markers.
+ require.NotContains(t, body, "%!")
+ require.NotContains(t, body, "MISSING")
+ require.NotContains(t, body, "EXTRA")
+}
+
+func TestBuildBalanceLowEmailBody_WithRechargeURL(t *testing.T) {
+ s := &BalanceNotifyService{}
+ body := s.buildBalanceLowEmailBody("Bob", 5.0, 20.0, "Site", "https://example.com/pay")
+
+ // The recharge anchor element should appear with the URL.
+ require.Contains(t, body, `href="https://example.com/pay"`)
+ require.Contains(t, body, "立即充值")
+ require.NotContains(t, body, "%!")
+}
+
+func TestBuildBalanceLowEmailBody_RechargeURLEscaped(t *testing.T) {
+ s := &BalanceNotifyService{}
+ // Try a URL with characters that need HTML escaping.
+ body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", `https://example.com/?a=1&b=
diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue
index 2e4eea0c..1c023fb3 100644
--- a/frontend/src/components/account/AccountUsageCell.vue
+++ b/frontend/src/components/account/AccountUsageCell.vue
@@ -439,15 +439,20 @@
+
+
+
+
- {{ t('admin.accounts.quotaLimitHint') }} + {{ t('admin.accounts.quotaControl.hint') }}
+ {{ t('admin.accounts.quotaLimitHint') }} +
++ {{ t('admin.accounts.anthropic.webSearchEmulationDesc') }} +
++ {{ t('admin.accounts.anthropic.webSearchEmulationDesc') }} +
++ {{ t('admin.accounts.quotaControl.hint') }} +
+{{ t('admin.accounts.quotaLimitHint') }}
@@ -1167,6 +1241,16 @@ :weeklyResetDay="editWeeklyResetDay" :weeklyResetHour="editWeeklyResetHour" :resetTimezone="editResetTimezone" + :quotaNotifyGlobalEnabled="quotaNotifyGlobalEnabled" + :quotaNotifyDailyEnabled="quotaNotifyState.daily.enabled" + :quotaNotifyDailyThreshold="quotaNotifyState.daily.threshold" + :quotaNotifyDailyThresholdType="quotaNotifyState.daily.thresholdType" + :quotaNotifyWeeklyEnabled="quotaNotifyState.weekly.enabled" + :quotaNotifyWeeklyThreshold="quotaNotifyState.weekly.threshold" + :quotaNotifyWeeklyThresholdType="quotaNotifyState.weekly.thresholdType" + :quotaNotifyTotalEnabled="quotaNotifyState.total.enabled" + :quotaNotifyTotalThreshold="quotaNotifyState.total.threshold" + :quotaNotifyTotalThresholdType="quotaNotifyState.total.thresholdType" @update:totalLimit="editQuotaLimit = $event" @update:dailyLimit="editQuotaDailyLimit = $event" @update:weeklyLimit="editQuotaWeeklyLimit = $event" @@ -1176,6 +1260,15 @@ @update:weeklyResetDay="editWeeklyResetDay = $event" @update:weeklyResetHour="editWeeklyResetHour = $event" @update:resetTimezone="editResetTimezone = $event" + @update:quotaNotifyDailyEnabled="quotaNotifyState.daily.enabled = $event" + @update:quotaNotifyDailyThreshold="quotaNotifyState.daily.threshold = $event" + @update:quotaNotifyDailyThresholdType="quotaNotifyState.daily.thresholdType = $event" + @update:quotaNotifyWeeklyEnabled="quotaNotifyState.weekly.enabled = $event" + @update:quotaNotifyWeeklyThreshold="quotaNotifyState.weekly.threshold = $event" + @update:quotaNotifyWeeklyThresholdType="quotaNotifyState.weekly.thresholdType = $event" + @update:quotaNotifyTotalEnabled="quotaNotifyState.total.enabled = $event" + @update:quotaNotifyTotalThreshold="quotaNotifyState.total.threshold = $event" + @update:quotaNotifyTotalThresholdType="quotaNotifyState.total.thresholdType = $event" />{{ hintRolling }}
+- {{ t('admin.accounts.quotaLimitToggleHint') }} -
++ {{ t('admin.accounts.quotaLimitToggleHint') }} +
+- - {{ t('admin.accounts.quotaDailyLimitHintFixed', { hour: String(dailyResetHour ?? 0).padStart(2, '0'), timezone: resetTimezone || 'UTC' }) }} - - - {{ t('admin.accounts.quotaDailyLimitHint') }} - -
-- - {{ t('admin.accounts.quotaWeeklyLimitHintFixed', { day: t('admin.accounts.dayOfWeek.' + (dayOptions.find(d => d.value === (weeklyResetDay ?? 1))?.key || 'monday')), hour: String(weeklyResetHour ?? 0).padStart(2, '0'), timezone: resetTimezone || 'UTC' }) }} - - - {{ t('admin.accounts.quotaWeeklyLimitHint') }} - -
-{{ t('admin.accounts.quotaTotalLimitHint') }}
-+ {{ t('admin.channels.form.applyPricingToAccountStatsDesc') }} +
++ {{ t('admin.channels.form.webSearchEmulationHint') }} +
++ {{ t('admin.channels.form.noRulesConfigured') }} +
+ ++ {{ t('admin.channels.form.noGroupsInChannel') }} +
++ {{ t('admin.channels.form.ruleAccountsHint') }} +
++ {{ t('admin.settings.webSearchEmulation.description') }} +
++ {{ t('admin.settings.webSearchEmulation.enabledHint') }} +
+{{ t('admin.settings.webSearchEmulation.quotaLimitHint') }}
+{{ t('admin.settings.webSearchEmulation.subscribedAtHint') }}
++ {{ t('admin.settings.webSearchEmulation.testResultProvider') }}: {{ wsTestResult.provider }} +
+{{ r.snippet }}
++ {{ t('admin.settings.balanceNotify.description') }} +
+{{ t('admin.settings.balanceNotify.thresholdHint') }}
+{{ t('admin.settings.balanceNotify.rechargeUrlHint') }}
++ {{ t('admin.settings.quotaNotify.description') }} +
+{{ t('admin.settings.quotaNotify.emailsHint') }}
+{{ errorMessage }}
@@ -110,9 +110,9 @@