From 1cd033e521b02e23fabf23628303663985b34114 Mon Sep 17 00:00:00 2001
From: erio
Date: Mon, 9 Mar 2026 19:24:19 +0800
Subject: [PATCH 01/88] style: apply gofmt formatting
Co-Authored-By: Claude Opus 4.6
---
backend/internal/repository/gateway_cache.go | 257 +++++++++++++++++-
.../service/admin_service_apikey_test.go | 72 +----
backend/internal/service/user_service_test.go | 9 +-
3 files changed, 257 insertions(+), 81 deletions(-)
diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go
index 58291b66..ec4bf40e 100644
--- a/backend/internal/repository/gateway_cache.go
+++ b/backend/internal/repository/gateway_cache.go
@@ -2,14 +2,42 @@ package repository
import (
"context"
+ _ "embed"
"fmt"
+ "strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
-const stickySessionPrefix = "sticky_session:"
+const (
+ stickySessionPrefix = "sticky_session:"
+ clientAffinityPrefix = "client_affinity:"
+ clientAffinityReversePrefix = "client_affinity_rev:"
+)
+
+var (
+ //go:embed lua/get_affinity.lua
+ getAffinityLua string
+ //go:embed lua/update_affinity.lua
+ updateAffinityLua string
+ //go:embed lua/get_affinity_count.lua
+ getAffinityCountLua string
+ //go:embed lua/get_affinity_clients.lua
+ getAffinityClientsLua string
+ //go:embed lua/get_affinity_clients_with_scores.lua
+ getAffinityClientsWithScoresLua string
+ //go:embed lua/clear_account_affinity.lua
+ clearAccountAffinityLua string
+
+ getAffinityScript = redis.NewScript(getAffinityLua)
+ updateAffinityScript = redis.NewScript(updateAffinityLua)
+ getAffinityCountScript = redis.NewScript(getAffinityCountLua)
+ getAffinityClientsScript = redis.NewScript(getAffinityClientsLua)
+ getAffinityClientsWithScoresScript = redis.NewScript(getAffinityClientsWithScoresLua)
+ clearAccountAffinityScript = redis.NewScript(clearAccountAffinityLua)
+)
type gatewayCache struct {
rdb *redis.Client
@@ -19,6 +47,16 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return &gatewayCache{rdb: rdb}
}
+// ensureScriptLoaded 确保 Lua 脚本已加载到 Redis 服务器的脚本缓存中。
+// Pipeline 中的 Script.Run 只发送 EVALSHA,如果 Redis 重启过导致脚本缓存丢失,
+// EVALSHA 会返回 NOSCRIPT 错误。此方法提前加载脚本以避免该问题。
+func ensureScriptLoaded(ctx context.Context, rdb *redis.Client, script *redis.Script) {
+ exists, err := script.Exists(ctx, rdb).Result()
+ if err != nil || len(exists) == 0 || !exists[0] {
+ _ = script.Load(ctx, rdb).Err()
+ }
+}
+
// buildSessionKey 构建 session key,包含 groupID 实现分组隔离
// 格式: sticky_session:{groupID}:{sessionHash}
func buildSessionKey(groupID int64, sessionHash string) string {
@@ -41,13 +79,218 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
}
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
-// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
-// 以便下次请求能够重新选择可用账号。
-//
-// DeleteSessionAccountID removes the sticky session binding for the given session.
-// Called when the bound account becomes unavailable (e.g., error status, disabled,
-// or unschedulable), allowing subsequent requests to select a new available account.
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err()
}
+
+// buildAffinityKey 构建正向亲和 key(client → accounts)
+// 格式: client_affinity:{groupID}:{clientID}
+func buildAffinityKey(groupID int64, clientID string) string {
+ return fmt.Sprintf("%s%d:%s", clientAffinityPrefix, groupID, clientID)
+}
+
+// buildAffinityReverseKey 构建反向亲和 key(account → clients)
+// 格式: client_affinity_rev:{groupID}:{accountID}
+func buildAffinityReverseKey(groupID int64, accountID int64) string {
+ return fmt.Sprintf("%s%d:%d", clientAffinityReversePrefix, groupID, accountID)
+}
+
+func (c *gatewayCache) GetClientAffinityAccounts(ctx context.Context, groupID int64, clientID string, ttl time.Duration) ([]int64, error) {
+ key := buildAffinityKey(groupID, clientID)
+ now := time.Now().Unix()
+ expireThreshold := now - int64(ttl.Seconds())
+
+ result, err := getAffinityScript.Run(ctx, c.rdb, []string{key}, expireThreshold).StringSlice()
+ if err != nil {
+ if err == redis.Nil {
+ return nil, nil
+ }
+ return nil, err
+ }
+
+ accountIDs := make([]int64, 0, len(result))
+ for _, s := range result {
+ id, err := strconv.ParseInt(s, 10, 64)
+ if err != nil {
+ continue
+ }
+ accountIDs = append(accountIDs, id)
+ }
+ return accountIDs, nil
+}
+
+func (c *gatewayCache) UpdateClientAffinity(ctx context.Context, groupID int64, clientID string, accountID int64, ttl time.Duration) error {
+ fwdKey := buildAffinityKey(groupID, clientID)
+ revKey := buildAffinityReverseKey(groupID, accountID)
+ now := time.Now().Unix()
+ ttlSeconds := int64(ttl.Seconds())
+ expireThreshold := now - ttlSeconds
+
+ return updateAffinityScript.Run(ctx, c.rdb, []string{fwdKey, revKey},
+ now, ttlSeconds, accountID, expireThreshold, clientID,
+ ).Err()
+}
+
+// GetAccountAffinityCountBatch 批量获取账号的亲和客户端数量(惰性清理过期成员)
+func (c *gatewayCache) GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error) {
+ if len(accountIDs) == 0 {
+ return map[int64]int64{}, nil
+ }
+
+ now := time.Now().Unix()
+ expireThreshold := now - int64(ttl.Seconds())
+
+ ensureScriptLoaded(ctx, c.rdb, getAffinityCountScript)
+
+ pipe := c.rdb.Pipeline()
+ cmds := make([]*redis.Cmd, len(accountIDs))
+ for i, accID := range accountIDs {
+ key := buildAffinityReverseKey(groupID, accID)
+ cmds[i] = getAffinityCountScript.Run(ctx, pipe, []string{key}, expireThreshold)
+ }
+ _, err := pipe.Exec(ctx)
+ if err != nil && err != redis.Nil {
+ return nil, err
+ }
+
+ result := make(map[int64]int64, len(accountIDs))
+ for i, accID := range accountIDs {
+ count, _ := cmds[i].Int64()
+ result[accID] = count
+ }
+ return result, nil
+}
+
+// GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和客户端列表(去重)。
+// accountGroups: map[accountID][]groupID,对每个 (groupID, accountID) 组合查询反向索引。
+func (c *gatewayCache) GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error) {
+ if len(accountGroups) == 0 {
+ return map[int64][]string{}, nil
+ }
+
+ now := time.Now().Unix()
+ expireThreshold := now - int64(ttl.Seconds())
+
+ // 构建所有 (accountID, groupID) 组合的查询
+ type queryItem struct {
+ accountID int64
+ groupID int64
+ }
+ var queries []queryItem
+ for accID, groupIDs := range accountGroups {
+ for _, gID := range groupIDs {
+ queries = append(queries, queryItem{accountID: accID, groupID: gID})
+ }
+ }
+
+ ensureScriptLoaded(ctx, c.rdb, getAffinityClientsScript)
+
+ pipe := c.rdb.Pipeline()
+ cmds := make([]*redis.Cmd, len(queries))
+ for i, q := range queries {
+ key := buildAffinityReverseKey(q.groupID, q.accountID)
+ cmds[i] = getAffinityClientsScript.Run(ctx, pipe, []string{key}, expireThreshold)
+ }
+ _, err := pipe.Exec(ctx)
+ if err != nil && err != redis.Nil {
+ return nil, err
+ }
+
+ // 合并结果:同一个 accountID 跨多个 group 的 clientID 去重
+ result := make(map[int64][]string, len(accountGroups))
+ seen := make(map[int64]map[string]struct{}, len(accountGroups))
+ for i, q := range queries {
+ clients, _ := cmds[i].StringSlice()
+ if len(clients) == 0 {
+ continue
+ }
+ if seen[q.accountID] == nil {
+ seen[q.accountID] = make(map[string]struct{})
+ }
+ for _, clientID := range clients {
+ if _, exists := seen[q.accountID][clientID]; !exists {
+ seen[q.accountID][clientID] = struct{}{}
+ result[q.accountID] = append(result[q.accountID], clientID)
+ }
+ }
+ }
+ return result, nil
+}
+
+// GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间戳,去重取最近)。
+func (c *gatewayCache) GetAccountAffinityClientsWithScores(
+ ctx context.Context,
+ accountID int64,
+ groupIDs []int64,
+ ttl time.Duration,
+) ([]service.AffinityClient, error) {
+ if len(groupIDs) == 0 {
+ return nil, nil
+ }
+
+ now := time.Now().Unix()
+ expireThreshold := now - int64(ttl.Seconds())
+
+ ensureScriptLoaded(ctx, c.rdb, getAffinityClientsWithScoresScript)
+
+ pipe := c.rdb.Pipeline()
+ cmds := make([]*redis.Cmd, len(groupIDs))
+ for i, gID := range groupIDs {
+ key := buildAffinityReverseKey(gID, accountID)
+ cmds[i] = getAffinityClientsWithScoresScript.Run(ctx, pipe, []string{key}, expireThreshold)
+ }
+ _, err := pipe.Exec(ctx)
+ if err != nil && err != redis.Nil {
+ return nil, err
+ }
+
+ // 合并跨组结果,同一 clientID 取最近的 lastActive
+ seen := make(map[string]int64) // clientID → max timestamp
+ for _, cmd := range cmds {
+ vals, _ := cmd.StringSlice()
+ // vals 格式: [clientID1, score1, clientID2, score2, ...]
+ for j := 0; j+1 < len(vals); j += 2 {
+ clientID := vals[j]
+ ts, _ := strconv.ParseInt(vals[j+1], 10, 64)
+ if existing, ok := seen[clientID]; !ok || ts > existing {
+ seen[clientID] = ts
+ }
+ }
+ }
+
+ result := make([]service.AffinityClient, 0, len(seen))
+ for clientID, ts := range seen {
+ result = append(result, service.AffinityClient{
+ ClientID: clientID,
+ LastActive: time.Unix(ts, 0),
+ })
+ }
+
+ // 按最后活跃时间降序排序
+ service.SortAffinityClients(result)
+
+ return result, nil
+}
+
+// ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引)。
+// 对每个 groupID 执行 Lua 脚本:读取反向索引获取所有客户端,
+// 从每个客户端的正向索引中移除该账号,然后删除反向索引。
+func (c *gatewayCache) ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error {
+ if len(groupIDs) == 0 {
+ return nil
+ }
+
+ ensureScriptLoaded(ctx, c.rdb, clearAccountAffinityScript)
+
+ pipe := c.rdb.Pipeline()
+ for _, gID := range groupIDs {
+ revKey := buildAffinityReverseKey(gID, accountID)
+ clearAccountAffinityScript.Run(ctx, pipe, []string{revKey}, gID, accountID)
+ }
+ _, err := pipe.Exec(ctx)
+ if err != nil && err != redis.Nil {
+ return err
+ }
+ return nil
+}
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
index f9fd6742..5c18a438 100644
--- a/backend/internal/service/admin_service_apikey_test.go
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -65,9 +65,6 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
panic("unexpected")
}
-func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
- panic("unexpected")
-}
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
panic("unexpected")
}
@@ -131,9 +128,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str
func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
-func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) {
- panic("unexpected")
-}
func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
@@ -200,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected")
}
-func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) {
+func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
@@ -216,29 +210,6 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS
panic("unexpected")
}
-type userSubRepoStubForGroupUpdate struct {
- userSubRepoNoop
- getActiveSub *UserSubscription
- getActiveErr error
- called bool
- calledUserID int64
- calledGroupID int64
-}
-
-func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) {
- s.called = true
- s.calledUserID = userID
- s.calledGroupID = groupID
- if s.getActiveErr != nil {
- return nil, s.getActiveErr
- }
- if s.getActiveSub == nil {
- return nil, ErrSubscriptionNotFound
- }
- clone := *s.getActiveSub
- return &clone, nil
-}
-
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
@@ -431,49 +402,14 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
- groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
- userRepo := &userRepoStubForGroupUpdate{}
- userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound}
- svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
-
- // 无有效订阅时应拒绝绑定
- _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
- require.Error(t, err)
- require.Equal(t, "SUBSCRIPTION_REQUIRED", infraerrors.Reason(err))
- require.True(t, userSubRepo.called)
- require.Equal(t, int64(42), userSubRepo.calledUserID)
- require.Equal(t, int64(10), userSubRepo.calledGroupID)
- require.False(t, userRepo.addGroupCalled)
-}
-
-func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) {
- existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
- apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
- groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
+ // 订阅类型分组应被阻止绑定
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
- require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err))
- require.False(t, userRepo.addGroupCalled)
-}
-
-func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) {
- existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
- apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
- groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
- userRepo := &userRepoStubForGroupUpdate{}
- userSubRepo := &userSubRepoStubForGroupUpdate{
- getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10},
- }
- svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
-
- got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
- require.NoError(t, err)
- require.True(t, userSubRepo.called)
- require.NotNil(t, got.APIKey.GroupID)
- require.Equal(t, int64(10), *got.APIKey.GroupID)
+ require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err))
require.False(t, userRepo.addGroupCalled)
}
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index e88694f5..7f6c748f 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -46,12 +46,9 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
return 0, nil
}
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
-func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
- return nil
-}
-func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
-func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
-func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
+func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
+func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
// --- mock: APIKeyAuthCacheInvalidator ---
From 3de77130175138146981d2e8d34ec8eb19a614b9 Mon Sep 17 00:00:00 2001
From: erio
Date: Mon, 30 Mar 2026 21:47:06 +0800
Subject: [PATCH 02/88] =?UTF-8?q?fix(channel):=20splice=E6=9B=BF=E6=8D=A2m?=
=?UTF-8?q?odel=5Fpricing=E6=9D=A1=E7=9B=AE=20+=20=E5=A2=9E=E5=BC=BA?=
=?UTF-8?q?=E8=B0=83=E8=AF=95=E6=97=A5=E5=BF=97?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/cmd/server/VERSION | 2 +-
frontend/src/views/admin/ChannelsView.vue | 1 +
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index 4b9b35d8..1ebb081e 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.112
+0.1.105.13
diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue
index ebfc1b5a..5c2f153b 100644
--- a/frontend/src/views/admin/ChannelsView.vue
+++ b/frontend/src/views/admin/ChannelsView.vue
@@ -970,6 +970,7 @@ async function handleSubmit() {
}
const { group_ids, model_pricing, model_mapping } = formToAPI()
+ console.log('[handleSubmit] model_pricing to send:', JSON.stringify(model_pricing))
submitting.value = true
try {
From 2dce4306b4409e355f6ff265fa30ae7a2d3a6221 Mon Sep 17 00:00:00 2001
From: erio
Date: Thu, 2 Apr 2026 13:24:30 +0800
Subject: [PATCH 03/88] refactor: move channel model restriction from handler
to scheduling phase
Move the model pricing restriction check from 8 handler entry points
to the account scheduling phase (SelectAccountForModelWithExclusions /
SelectAccountWithLoadAwareness), aligning restriction with billing:
- requested: check original request model against pricing list
- channel_mapped: check channel-mapped model against pricing list
- upstream: per-account check using account-mapped model
Handler layer now only resolves channel mapping (no restriction).
Scheduling layer performs pre-check for requested/channel_mapped,
and per-account filtering for upstream billing source.
---
.../gateway_handler_chat_completions.go | 2 +-
.../handler/gateway_handler_responses.go | 2 +-
.../internal/handler/gemini_v1beta_handler.go | 14 +-
.../handler/openai_chat_completions.go | 2 +-
.../handler/openai_gateway_handler.go | 46 +-
backend/internal/service/gateway_service.go | 1202 +++++++++++------
6 files changed, 793 insertions(+), 475 deletions(-)
diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go
index be267332..abe2a1e5 100644
--- a/backend/internal/handler/gateway_handler_chat_completions.go
+++ b/backend/internal/handler/gateway_handler_chat_completions.go
@@ -80,7 +80,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
- // 解析渠道级模型映射
+ // 解析渠道级模型映射 + 限制检查
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction
diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go
index e908eb9e..cf877182 100644
--- a/backend/internal/handler/gateway_handler_responses.go
+++ b/backend/internal/handler/gateway_handler_responses.go
@@ -80,7 +80,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
- // 解析渠道级模型映射
+ // 解析渠道级模型映射 + 限制检查
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction:
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index d200c17c..ff63bc7f 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
googleError(c, http.StatusBadGateway, err.Error())
return
}
- if shouldFallbackGeminiModel(modelName, res) {
+ if shouldFallbackGeminiModels(res) {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
@@ -184,7 +184,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
- // 解析渠道级模型映射
+ // 解析渠道级模型映射 + 限制检查
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
reqModel := modelName // 保存映射前的原始模型名
if channelMapping.Mapped {
@@ -682,16 +682,6 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
return false
}
-func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool {
- if shouldFallbackGeminiModels(res) {
- return true
- }
- if res == nil || res.StatusCode != http.StatusNotFound {
- return false
- }
- return gemini.HasFallbackModel(modelName)
-}
-
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
//
diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go
index 991cbb91..ada401c9 100644
--- a/backend/internal/handler/openai_chat_completions.go
+++ b/backend/internal/handler/openai_chat_completions.go
@@ -79,7 +79,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
- // 解析渠道级模型映射
+ // 解析渠道级模型映射 + 限制检查
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if h.errorPassthroughService != nil {
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 5319b55d..2b081617 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -47,13 +47,6 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode
return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
}
-func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
- if apiKey == nil || apiKey.Group == nil {
- return ""
- }
- return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel))
-}
-
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
@@ -557,8 +550,6 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
return
}
reqModel := modelResult.String()
- routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel)
- preferredMappedModel := resolveOpenAIMessagesDispatchMappedModel(apiKey, reqModel)
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
@@ -617,20 +608,17 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError
- effectiveMappedModel := preferredMappedModel
for {
- currentRoutingModel := routingModel
- if effectiveMappedModel != "" {
- currentRoutingModel = effectiveMappedModel
- }
+ // 清除上一次迭代的降级模型标记,避免残留影响本次迭代
+ c.Set("openai_messages_fallback_model", "")
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"", // no previous_response_id
sessionHash,
- currentRoutingModel,
+ reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
@@ -639,7 +627,29 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
+ // 首次调度失败 + 有默认映射模型 → 用默认模型重试
if len(failedAccountIDs) == 0 {
+ defaultModel := ""
+ if apiKey.Group != nil {
+ defaultModel = apiKey.Group.DefaultMappedModel
+ }
+ if defaultModel != "" && defaultModel != reqModel {
+ reqLog.Info("openai_messages.fallback_to_default_model",
+ zap.String("default_mapped_model", defaultModel),
+ )
+ selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
+ c.Request.Context(),
+ apiKey.GroupID,
+ "",
+ sessionHash,
+ defaultModel,
+ failedAccountIDs,
+ service.OpenAIUpstreamTransportAny,
+ )
+ if err == nil && selection != nil {
+ c.Set("openai_messages_fallback_model", defaultModel)
+ }
+ }
if err != nil {
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
@@ -671,7 +681,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
- defaultMappedModel := strings.TrimSpace(effectiveMappedModel)
+ // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
+ // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
+ defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
// 应用渠道模型映射到请求体
forwardBody := body
if channelMappingMsg.Mapped {
@@ -1106,7 +1118,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
- // 解析渠道级模型映射
+ // 解析渠道级模型映射 + 限制检查
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
var currentUserRelease func()
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 8b0bdc2a..33ab38f2 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -12,7 +12,6 @@ import (
"log/slog"
mathrand "math/rand"
"net/http"
- "net/url"
"os"
"path/filepath"
"regexp"
@@ -42,7 +41,8 @@ import (
const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
- stickySessionTTL = time.Hour // 粘性会话TTL
+ stickySessionTTL = time.Hour // 粘性会话TTL
+ ClientAffinityTTL = 24 * time.Hour // 客户端亲和TTL
defaultMaxLineSize = 500 * 1024 * 1024
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
// to match real Claude CLI traffic as closely as possible. When we need a visual
@@ -60,14 +60,28 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
+// MediaType 媒体类型常量
+const (
+ MediaTypeImage = "image"
+ MediaTypeVideo = "video"
+ MediaTypePrompt = "prompt"
+)
+
+const (
+ claudeMaxMessageOverheadTokens = 3
+ claudeMaxBlockOverheadTokens = 1
+ claudeMaxUnknownContentTokens = 4
+)
+
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
// accountWithLoad 账号与负载信息的组合,用于负载感知调度
type accountWithLoad struct {
- account *Account
- loadInfo *AccountLoadInfo
+ account *Account
+ loadInfo *AccountLoadInfo
+ affinityCount int64 // 亲和客户端数量(反向索引),越少越优先
}
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
@@ -331,6 +345,10 @@ var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
+ // clientIDFromMetadataRegex 从 metadata.user_id 中提取客户端 ID(64位 hex)
+ // 格式: user_{64位hex}_account_...
+ clientIDFromMetadataRegex = regexp.MustCompile(`^user_([a-f0-9]{64})_account_`)
+
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
// 注意:前缀之间不应存在包含关系,否则会导致冗余匹配
@@ -348,6 +366,12 @@ var ErrNoAvailableAccounts = errors.New("no available accounts")
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
+// ErrAffinityNoSwitch 表示亲和账号不可用且不允许切换到其他账号
+var ErrAffinityNoSwitch = errors.New("affinity account unavailable and switching is disabled")
+
+// ErrAffinityLimitExceeded 表示亲和客户端限制已达上限
+var ErrAffinityLimitExceeded = errors.New("affinity client limit exceeded")
+
// allowedHeaders 白名单headers(参考CRS项目)
var allowedHeaders = map[string]bool{
"accept": true,
@@ -369,8 +393,6 @@ var allowedHeaders = map[string]bool{
"user-agent": true,
"content-type": true,
"accept-encoding": true,
- "x-claude-code-session-id": true,
- "x-client-request-id": true,
}
// GatewayCache 定义网关服务的缓存操作接口。
@@ -391,6 +413,39 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
+
+ // GetAffinityAccounts 获取亲和账号列表(按最近使用降序),同时清理过期成员
+ GetAffinityAccounts(ctx context.Context, groupID int64, userID int64, clientID string, ttl time.Duration) ([]int64, error)
+ // UpdateAffinity 添加/更新亲和关系(更新 score 为当前时间戳,刷新 key TTL)
+ UpdateAffinity(ctx context.Context, groupID int64, userID int64, clientID string, accountID int64, ttl time.Duration) error
+ // GetAccountAffinityCountBatch 批量获取账号的亲和成员数量(惰性清理过期成员)
+ GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error)
+ // GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和成员列表(去重)
+ // accountGroups: map[accountID][]groupID
+ // 返回值成员格式为 {userID}/{clientID}
+ GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error)
+ // GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间)
+ GetAccountAffinityClientsWithScores(ctx context.Context, accountID int64, groupIDs []int64, ttl time.Duration) ([]AffinityClient, error)
+ // ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引)
+ // 用于账号关闭亲和时立即清理旧绑定
+ ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error
+ // GetAffinityMultiCount 获取账号的多维度亲和计数
+ // 返回: uniqueUsers, uniqueClients, perUserClients
+ GetAffinityMultiCount(ctx context.Context, groupID int64, accountID int64, targetUserID int64, ttl time.Duration) (users, clients, perUser int64, err error)
+}
+
+// AffinityClient 亲和客户端信息(含用户 ID 和最后活跃时间)
+type AffinityClient struct {
+ UserID int64 `json:"user_id"`
+ ClientID string `json:"client_id"`
+ LastActive time.Time `json:"last_active"`
+}
+
+// SortAffinityClients 按最后活跃时间降序排序
+func SortAffinityClients(clients []AffinityClient) {
+ sort.Slice(clients, func(i, j int) bool {
+ return clients[i].LastActive.After(clients[j].LastActive)
+ })
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
@@ -461,6 +516,20 @@ func shouldClearStickySession(account *Account, requestedModel string) bool {
return false
}
+// extractClientIDFromMetadata 从 metadata.user_id 中提取客户端 ID(64位 hex)。
+// 格式: user_{64位hex}_account_..._session_...
+// 返回空字符串表示无法提取(非 Claude Code/Console 客户端)。
+func extractClientIDFromMetadata(metadataUserID string) string {
+ if metadataUserID == "" {
+ return ""
+ }
+ matches := clientIDFromMetadataRegex.FindStringSubmatch(metadataUserID)
+ if matches == nil {
+ return ""
+ }
+ return matches[1]
+}
+
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
@@ -504,6 +573,9 @@ type ForwardResult struct {
ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K"
+ // Sora 媒体字段
+ MediaType string // image / video / prompt
+ MediaURL string // 生成后的媒体地址(可选)
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
@@ -1162,6 +1234,11 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
+ // 渠道定价限制预检查(requested / channel_mapped 基准)
+ if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
+ return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
+ }
+
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
@@ -1180,32 +1257,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
platform = PlatformAnthropic
}
- // Claude Code 限制可能已将 groupID 解析为 fallback group,
- // 渠道限制预检查必须使用解析后的分组。
- if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
- slog.Warn("channel pricing restriction blocked request",
- "group_id", derefGroupID(groupID),
- "model", requestedModel)
- return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
- }
-
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
- account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
- if err != nil {
- return nil, err
- }
- return s.hydrateSelectedAccount(ctx, account)
+ return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// antigravity 分组、强制平台模式或无分组使用单平台选择
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
- account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
- if err != nil {
- return nil, err
- }
- return s.hydrateSelectedAccount(ctx, account)
+ return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
@@ -1213,6 +1273,11 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
// sub2apiUserID: 系统用户 ID,用于二维亲和调度
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
+ // 渠道定价限制预检查(requested / channel_mapped 基准)
+ if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
+ return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
+ }
+
// 调试日志:记录调度入口参数
excludedIDsList := make([]int64, 0, len(excludedIDs))
for id := range excludedIDs {
@@ -1233,15 +1298,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
ctx = s.withGroupContext(ctx, group)
- // Claude Code 限制可能已将 groupID 解析为 fallback group,
- // 渠道限制预检查必须使用解析后的分组。
- if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
- slog.Warn("channel pricing restriction blocked request",
- "group_id", derefGroupID(groupID),
- "model", requestedModel)
- return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
- }
-
var stickyAccountID int64
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
stickyAccountID = prefetch
@@ -1251,6 +1307,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
+ // 提取客户端 ID(用于客户端亲和调度)
+ affinityClientID := extractClientIDFromMetadata(metadataUserID)
+ affinityUserID := sub2apiUserID
+
if s.debugModelRoutingEnabled() && requestedModel != "" {
groupPlatform := ""
if group != nil {
@@ -1272,6 +1332,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if err != nil {
return nil, err
}
+ if shouldFilterAccountWithoutClientID(account, affinityClientID) {
+ localExcluded[account.ID] = struct{}{}
+ continue
+ }
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
@@ -1281,7 +1345,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
localExcluded[account.ID] = struct{}{} // 排除此账号
continue // 重新选择
}
- return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
+ return &AccountSelectionResult{
+ Account: account,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
}
// 对于等待计划的情况,也需要先检查会话限制
@@ -1293,20 +1361,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
- return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
- AccountID: account.ID,
- MaxConcurrency: account.Concurrency,
- Timeout: cfg.StickySessionWaitTimeout,
- MaxWaiting: cfg.StickySessionMaxWaiting,
- })
+ return &AccountSelectionResult{
+ Account: account,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: account.ID,
+ MaxConcurrency: account.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
}
}
- return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
- AccountID: account.ID,
- MaxConcurrency: account.Concurrency,
- Timeout: cfg.FallbackWaitTimeout,
- MaxWaiting: cfg.FallbackMaxWaiting,
- })
+ return &AccountSelectionResult{
+ Account: account,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: account.ID,
+ MaxConcurrency: account.Concurrency,
+ Timeout: cfg.FallbackWaitTimeout,
+ MaxWaiting: cfg.FallbackMaxWaiting,
+ },
+ }, nil
}
}
@@ -1323,12 +1397,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if err != nil {
return nil, err
}
+ accounts = filterAccountsWithoutClientID(accounts, affinityClientID)
if len(accounts) == 0 {
return nil, ErrNoAvailableAccounts
}
ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts)
+ // 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
+ accountByID := make(map[int64]*Account, len(accounts))
+ for i := range accounts {
+ accountByID[accounts[i].ID] = &accounts[i]
+ }
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
@@ -1336,12 +1416,19 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
_, excluded := excludedIDs[accountID]
return excluded
}
-
- // 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
- accountByID := make(map[int64]*Account, len(accounts))
- for i := range accounts {
- accountByID[accounts[i].ID] = &accounts[i]
- }
+ affinityFlow := newGatewayAffinityFlow(
+ s,
+ ctx,
+ groupID,
+ sessionHash,
+ requestedModel,
+ affinityClientID,
+ affinityUserID,
+ platform,
+ useMixed,
+ accountByID,
+ isExcluded,
+ )
// 获取模型路由配置(仅 anthropic 平台)
var routingAccountIDs []int64
@@ -1430,76 +1517,53 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
// 粘性账号在路由列表中,优先使用
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
- var stickyCacheMissReason string
-
- gatePass := s.isAccountSchedulableForSelection(stickyAccount) &&
+ if s.isAccountSchedulableForSelection(stickyAccount) &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
s.isAccountSchedulableForQuota(stickyAccount) &&
- s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true)
+ s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
- rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true)
-
- if rpmPass { // 粘性会话窗口费用+RPM 检查
+ s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
result.ReleaseFunc() // 释放槽位
- stickyCacheMissReason = "session_limit"
// 继续到负载感知选择
} else {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
}
- return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil)
+ return &AccountSelectionResult{
+ Account: stickyAccount,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
}
}
- if stickyCacheMissReason == "" {
- waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
- if waitingCount < cfg.StickySessionMaxWaiting {
- // 会话数量限制检查(等待计划也需要占用会话配额)
- if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
- stickyCacheMissReason = "session_limit"
- // 会话限制已满,继续到负载感知选择
- } else {
- return &AccountSelectionResult{
- Account: stickyAccount,
- WaitPlan: &AccountWaitPlan{
- AccountID: stickyAccountID,
- MaxConcurrency: stickyAccount.Concurrency,
- Timeout: cfg.StickySessionWaitTimeout,
- MaxWaiting: cfg.StickySessionMaxWaiting,
- },
- }, nil
- }
+ waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
+ if waitingCount < cfg.StickySessionMaxWaiting {
+ // 会话数量限制检查(等待计划也需要占用会话配额)
+ if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
+ // 会话限制已满,继续到负载感知选择
} else {
- stickyCacheMissReason = "wait_queue_full"
+ return &AccountSelectionResult{
+ Account: stickyAccount,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: stickyAccountID,
+ MaxConcurrency: stickyAccount.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
}
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
- } else if !gatePass {
- stickyCacheMissReason = "gate_check"
- } else {
- stickyCacheMissReason = "rpm_red"
- }
-
- // 记录粘性缓存未命中的结构化日志
- if stickyCacheMissReason != "" {
- baseRPM := stickyAccount.GetBaseRPM()
- var currentRPM int
- if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok {
- currentRPM = count
- }
- logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d",
- stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM)
}
} else {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
- logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0",
- stickyAccountID, shortSessionHash(sessionHash))
}
}
}
@@ -1527,7 +1591,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
if len(routingAvailable) > 0 {
- // 排序:优先级 > 负载率 > 最后使用时间
+ // 批量获取亲和客户端数量
+ s.populateAffinityCounts(ctx, routingAvailable, derefGroupID(groupID))
+
+ // 排序:优先级 > 负载率 > 亲和客户端数 > 最后使用时间
sort.SliceStable(routingAvailable, func(i, j int) bool {
a, b := routingAvailable[i], routingAvailable[j]
if a.account.Priority != b.account.Priority {
@@ -1536,6 +1603,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
+ if a.affinityCount != b.affinityCount {
+ return a.affinityCount < b.affinityCount
+ }
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
@@ -1561,10 +1631,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
}
+ if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && item.account.IsAffinityEnabled() {
+ _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, item.account.ID, ClientAffinityTTL)
+ }
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
- return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil)
+ return &AccountSelectionResult{
+ Account: item.account,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
}
}
@@ -1577,12 +1654,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
- return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{
- AccountID: item.account.ID,
- MaxConcurrency: item.account.Concurrency,
- Timeout: cfg.StickySessionWaitTimeout,
- MaxWaiting: cfg.StickySessionMaxWaiting,
- })
+ return &AccountSelectionResult{
+ Account: item.account,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: item.account.ID,
+ MaxConcurrency: item.account.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
}
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
}
@@ -1591,14 +1671,27 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
- // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
- if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) {
+ // ============ Layer 1.3: 用户亲和预处理(pinned_users 自动注入) ============
+ affinityFlow.preprocessPinnedUsers(accounts)
+
+ // ============ Layer 1.4: 客户端亲和调度(优先于粘性会话) ============
+ affinityHit := false
+ if affinityResult, hit, err := affinityFlow.trySelectAffinityAccount(); err != nil {
+ return nil, err
+ } else {
+ affinityHit = hit
+ if affinityResult != nil {
+ return affinityResult, nil
+ }
+ }
+
+ // ============ Layer 1.5: 粘性会话(仅在无模型路由配置 且 亲和未命中时生效) ============
+ if !affinityHit && len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) {
accountID := stickyAccountID
if accountID > 0 && !isExcluded(accountID) {
account, ok := accountByID[accountID]
if ok {
// 检查账户是否需要清理粘性会话绑定
- // Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
@@ -1614,31 +1707,32 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
- // Session count limit check
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
- if s.cache != nil {
- _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
- }
- return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
+ return &AccountSelectionResult{
+ Account: account,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
}
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
- // Session count limit check (wait plan also requires session quota)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
- // Session limit full, continue to Layer 2
} else {
- return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
- AccountID: accountID,
- MaxConcurrency: account.Concurrency,
- Timeout: cfg.StickySessionWaitTimeout,
- MaxWaiting: cfg.StickySessionMaxWaiting,
- })
+ return &AccountSelectionResult{
+ Account: account,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: accountID,
+ MaxConcurrency: account.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
}
}
}
@@ -1697,9 +1791,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
- if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil {
- return nil, legacyErr
- } else if ok {
+ if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
+ if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && result.Account != nil && result.Account.IsAffinityEnabled() {
+ _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, result.Account.ID, ClientAffinityTTL)
+ }
return result, nil
}
} else {
@@ -1717,13 +1812,37 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
- // 分层过滤选择:优先级 → 负载率 → LRU
+ // 批量获取亲和客户端数量(用于均衡分配新客户端)
+ s.populateAffinityCounts(ctx, available, derefGroupID(groupID))
+
+ // 分层过滤选择:优先级 → 亲和三区 → 负载率 → 亲和客户端数 → LRU
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
- // 2. 取负载率最低的集合
+ // 2. 按亲和三区过滤:绿区优先 → 黄区降级 → 红区移除(在同优先级内)
+ candidates = classifyByAffinityZone(candidates)
+ if len(candidates) == 0 {
+ // 当前优先级组全部在红区,移除后回退到下一优先级组
+ minPri := available[0].account.Priority
+ for _, a := range available[1:] {
+ if a.account.Priority < minPri {
+ minPri = a.account.Priority
+ }
+ }
+ newAvailable := make([]accountWithLoad, 0, len(available))
+ for _, a := range available {
+ if a.account.Priority != minPri {
+ newAvailable = append(newAvailable, a)
+ }
+ }
+ available = newAvailable
+ continue
+ }
+ // 3. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
- // 3. LRU 选择最久未用的账号
+ // 3. 取亲和客户端数最少的集合
+ candidates = filterByMinAffinityCount(candidates)
+ // 4. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
@@ -1738,7 +1857,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
- return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil)
+ // 更新亲和关系
+ if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && selected.account.IsAffinityEnabled() {
+ _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, selected.account.ID, ClientAffinityTTL)
+ }
+ return &AccountSelectionResult{
+ Account: selected.account,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
}
}
@@ -1761,17 +1888,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
continue // 会话限制已满,尝试下一个账号
}
- return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{
- AccountID: acc.ID,
- MaxConcurrency: acc.Concurrency,
- Timeout: cfg.FallbackWaitTimeout,
- MaxWaiting: cfg.FallbackMaxWaiting,
- })
+ return &AccountSelectionResult{
+ Account: acc,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: acc.ID,
+ MaxConcurrency: acc.Concurrency,
+ Timeout: cfg.FallbackWaitTimeout,
+ MaxWaiting: cfg.FallbackMaxWaiting,
+ },
+ }, nil
}
return nil, ErrNoAvailableAccounts
}
-func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) {
+func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
@@ -1786,15 +1916,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
}
- selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil)
- if err != nil {
- return nil, false, err
- }
- return selection, true, nil
+ return &AccountSelectionResult{
+ Account: acc,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, true
}
}
- return nil, false, nil
+ return nil, false
}
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
@@ -1939,6 +2069,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
+ if platform == PlatformSora {
+ return s.listSoraSchedulableAccounts(ctx, groupID)
+ }
if s.schedulerSnapshot != nil {
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err == nil {
@@ -2035,6 +2168,53 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
return accounts, useMixed, nil
}
+func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) {
+ const useMixed = false
+
+ var accounts []Account
+ var err error
+ if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
+ accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
+ } else if groupID != nil {
+ accounts, err = s.accountRepo.ListByGroup(ctx, *groupID)
+ } else {
+ accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
+ }
+ if err != nil {
+ slog.Debug("account_scheduling_list_failed",
+ "group_id", derefGroupID(groupID),
+ "platform", PlatformSora,
+ "error", err)
+ return nil, useMixed, err
+ }
+
+ filtered := make([]Account, 0, len(accounts))
+ for _, acc := range accounts {
+ if acc.Platform != PlatformSora {
+ continue
+ }
+ if !s.isSoraAccountSchedulable(&acc) {
+ continue
+ }
+ filtered = append(filtered, acc)
+ }
+ slog.Debug("account_scheduling_list_sora",
+ "group_id", derefGroupID(groupID),
+ "platform", PlatformSora,
+ "raw_count", len(accounts),
+ "filtered_count", len(filtered))
+ for _, acc := range filtered {
+ slog.Debug("account_scheduling_account_detail",
+ "account_id", acc.ID,
+ "name", acc.Name,
+ "platform", acc.Platform,
+ "type", acc.Type,
+ "status", acc.Status,
+ "tls_fingerprint", acc.IsTLSFingerprintEnabled())
+ }
+ return filtered, useMixed, nil
+}
+
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context,
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
@@ -2059,10 +2239,33 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
return account.Platform == platform
}
+func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool {
+ return s.soraUnschedulableReason(account) == ""
+}
+
+func (s *GatewayService) soraUnschedulableReason(account *Account) string {
+ if account == nil {
+ return "account_nil"
+ }
+ if account.Status != StatusActive {
+ return fmt.Sprintf("status=%s", account.Status)
+ }
+ if !account.Schedulable {
+ return "schedulable=false"
+ }
+ if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
+ return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339))
+ }
+ return ""
+}
+
func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool {
if account == nil {
return false
}
+ if account.Platform == PlatformSora {
+ return s.isSoraAccountSchedulable(account)
+ }
return account.IsSchedulable()
}
@@ -2070,6 +2273,12 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
if account == nil {
return false
}
+ if account.Platform == PlatformSora {
+ if !s.isSoraAccountSchedulable(account) {
+ return false
+ }
+ return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0
+ }
return account.IsSchedulableForModelWithContext(ctx, requestedModel)
}
@@ -2409,31 +2618,34 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
return s.accountRepo.GetByID(ctx, accountID)
}
-func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) {
- if account == nil || s.schedulerSnapshot == nil {
- return account, nil
+// populateAffinityCounts 批量获取账号的亲和客户端数量并填入 accountWithLoad 切片。
+// 仅当存在开启了客户端亲和的账号时才查询 Redis,否则跳过。
+func (s *GatewayService) populateAffinityCounts(ctx context.Context, accounts []accountWithLoad, groupID int64) {
+ if s.cache == nil || len(accounts) == 0 {
+ return
}
- hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID)
+ // 快速检查:是否有任何账号开启了亲和
+ hasAffinity := false
+ for _, acc := range accounts {
+ if acc.account.IsAffinityEnabled() {
+ hasAffinity = true
+ break
+ }
+ }
+ if !hasAffinity {
+ return
+ }
+ accountIDs := make([]int64, len(accounts))
+ for i, acc := range accounts {
+ accountIDs[i] = acc.account.ID
+ }
+ countMap, err := s.cache.GetAccountAffinityCountBatch(ctx, groupID, accountIDs, ClientAffinityTTL)
if err != nil {
- return nil, err
+ return // 查询失败不影响调度,affinityCount 保持 0
}
- if hydrated == nil {
- return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID)
+ for i := range accounts {
+ accounts[i].affinityCount = countMap[accounts[i].account.ID]
}
- return hydrated, nil
-}
-
-func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) {
- hydrated, err := s.hydrateSelectedAccount(ctx, account)
- if err != nil {
- return nil, err
- }
- return &AccountSelectionResult{
- Account: hydrated,
- Acquired: acquired,
- ReleaseFunc: release,
- WaitPlan: waitPlan,
- }, nil
}
// filterByMinPriority 过滤出优先级最小的账号集合
@@ -2476,6 +2688,64 @@ func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad {
return result
}
+// filterByMinAffinityCount 过滤出亲和客户端数最少的账号集合
+func filterByMinAffinityCount(accounts []accountWithLoad) []accountWithLoad {
+ if len(accounts) == 0 {
+ return accounts
+ }
+ minCount := accounts[0].affinityCount
+ for _, acc := range accounts[1:] {
+ if acc.affinityCount < minCount {
+ minCount = acc.affinityCount
+ }
+ }
+ result := make([]accountWithLoad, 0, len(accounts))
+ for _, acc := range accounts {
+ if acc.affinityCount == minCount {
+ result = append(result, acc)
+ }
+ }
+ return result
+}
+
+// classifyByAffinityZone 按亲和分区对候选账号进行分类。
+// 返回值:仅绿区账号(有绿区时),否则返回黄区账号。红区账号被移除。
+// 如果没有任何账号开启了亲和三区配置(即 affinity_base <= 0),则原样返回所有账号。
+func classifyByAffinityZone(accounts []accountWithLoad) []accountWithLoad {
+ if len(accounts) == 0 {
+ return accounts
+ }
+ // 快速检查:是否有任何账号配置了 affinity_base
+ hasZoneConfig := false
+ for _, acc := range accounts {
+ if acc.account.IsAffinityEnabled() && acc.account.GetAffinityBase() > 0 {
+ hasZoneConfig = true
+ break
+ }
+ }
+ if !hasZoneConfig {
+ return accounts
+ }
+
+ greens := make([]accountWithLoad, 0, len(accounts))
+ yellows := make([]accountWithLoad, 0, len(accounts))
+ for _, acc := range accounts {
+ zone := acc.account.GetAffinityZone(acc.affinityCount)
+ switch zone {
+ case AffinityZoneGreen:
+ greens = append(greens, acc)
+ case AffinityZoneYellow:
+ yellows = append(yellows, acc)
+ case AffinityZoneRed:
+ // 红区:移除,不参与调度
+ }
+ }
+ if len(greens) > 0 {
+ return greens
+ }
+ return yellows
+}
+
// selectByLRU 从集合中选择最久未用的账号
// 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个
func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad {
@@ -2711,12 +2981,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
- // require_privacy_set: 获取分组信息
- var schedGroup *Group
- if groupID != nil && s.groupRepo != nil {
- schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
- }
-
var accounts []Account
accountsLoaded := false
@@ -2788,12 +3052,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) {
continue
}
- // require_privacy_set: 跳过 privacy 未设置的账号并标记异常
- if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
- _ = s.accountRepo.SetError(ctx, acc.ID,
- fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
- continue
- }
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
@@ -2885,8 +3143,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持)
- // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查,
- // 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
@@ -2899,12 +3155,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) {
continue
}
- // require_privacy_set: 跳过 privacy 未设置的账号并标记异常
- if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
- _ = s.accountRepo.SetError(ctx, acc.ID,
- fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
- continue
- }
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
@@ -2971,12 +3221,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
preferOAuth := nativePlatform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
- // require_privacy_set: 获取分组信息
- var schedGroup *Group
- if groupID != nil && s.groupRepo != nil {
- schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
- }
-
var accounts []Account
accountsLoaded := false
@@ -3044,12 +3288,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) {
continue
}
- // require_privacy_set: 跳过 privacy 未设置的账号并标记异常
- if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
- _ = s.accountRepo.SetError(ctx, acc.ID,
- fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
- continue
- }
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
@@ -3143,7 +3381,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
- // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
@@ -3156,12 +3393,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) {
continue
}
- // require_privacy_set: 跳过 privacy 未设置的账号并标记异常
- if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
- _ = s.accountRepo.SetError(ctx, acc.ID,
- fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
- continue
- }
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
@@ -3273,6 +3504,9 @@ func (s *GatewayService) logDetailedSelectionFailure(
stats.SampleMappingIDs,
stats.SampleRateLimitIDs,
)
+ if platform == PlatformSora {
+ s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling)
+ }
return stats
}
@@ -3329,7 +3563,11 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "excluded"}
}
if !s.isAccountSchedulableForSelection(acc) {
- return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
+ detail := "generic_unschedulable"
+ if acc.Platform == PlatformSora {
+ detail = s.soraUnschedulableReason(acc)
+ }
+ return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
}
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
return selectionFailureDiagnosis{
@@ -3353,6 +3591,57 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "eligible"}
}
+func (s *GatewayService) logSoraSelectionFailureDetails(
+ ctx context.Context,
+ groupID *int64,
+ sessionHash string,
+ requestedModel string,
+ accounts []Account,
+ excludedIDs map[int64]struct{},
+ allowMixedScheduling bool,
+) {
+ const maxLines = 30
+ logged := 0
+
+ for i := range accounts {
+ if logged >= maxLines {
+ break
+ }
+ acc := &accounts[i]
+ diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling)
+ if diagnosis.Category == "eligible" {
+ continue
+ }
+ detail := diagnosis.Detail
+ if detail == "" {
+ detail = "-"
+ }
+ logger.LegacyPrintf(
+ "service.gateway",
+ "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s",
+ derefGroupID(groupID),
+ requestedModel,
+ shortSessionHash(sessionHash),
+ acc.ID,
+ acc.Platform,
+ diagnosis.Category,
+ detail,
+ )
+ logged++
+ }
+ if len(accounts) > maxLines {
+ logger.LegacyPrintf(
+ "service.gateway",
+ "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d",
+ derefGroupID(groupID),
+ requestedModel,
+ shortSessionHash(sessionHash),
+ len(accounts),
+ logged,
+ )
+ }
+}
+
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
if acc == nil {
return true
@@ -3431,10 +3720,17 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
return mapAntigravityModel(account, requestedModel) != ""
}
+ if account.Platform == PlatformSora {
+ return s.isSoraModelSupportedByAccount(account, requestedModel)
+ }
if account.IsBedrock() {
_, ok := ResolveBedrockModelID(account, requestedModel)
return ok
}
+ // OpenAI 透传模式:仅替换认证,允许所有模型
+ if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
+ return true
+ }
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel)
@@ -3443,6 +3739,143 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
return account.IsModelSupported(requestedModel)
}
+func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool {
+ if account == nil {
+ return false
+ }
+ if strings.TrimSpace(requestedModel) == "" {
+ return true
+ }
+
+ // 先走原始精确/通配符匹配。
+ mapping := account.GetModelMapping()
+ if len(mapping) == 0 || account.IsModelSupported(requestedModel) {
+ return true
+ }
+
+ aliases := buildSoraModelAliases(requestedModel)
+ if len(aliases) == 0 {
+ return false
+ }
+
+ hasSoraSelector := false
+ for pattern := range mapping {
+ if !isSoraModelSelector(pattern) {
+ continue
+ }
+ hasSoraSelector = true
+ if matchPatternAnyAlias(pattern, aliases) {
+ return true
+ }
+ }
+
+ // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*),
+ // 此时不应误拦截 Sora 模型请求。
+ if !hasSoraSelector {
+ return true
+ }
+
+ return false
+}
+
+func matchPatternAnyAlias(pattern string, aliases []string) bool {
+ normalizedPattern := strings.ToLower(strings.TrimSpace(pattern))
+ if normalizedPattern == "" {
+ return false
+ }
+ for _, alias := range aliases {
+ if matchWildcard(normalizedPattern, alias) {
+ return true
+ }
+ }
+ return false
+}
+
+func isSoraModelSelector(pattern string) bool {
+ p := strings.ToLower(strings.TrimSpace(pattern))
+ if p == "" {
+ return false
+ }
+
+ switch {
+ case strings.HasPrefix(p, "sora"),
+ strings.HasPrefix(p, "gpt-image"),
+ strings.HasPrefix(p, "prompt-enhance"),
+ strings.HasPrefix(p, "sy_"):
+ return true
+ }
+
+ return p == "video" || p == "image"
+}
+
+func buildSoraModelAliases(requestedModel string) []string {
+ modelID := strings.ToLower(strings.TrimSpace(requestedModel))
+ if modelID == "" {
+ return nil
+ }
+
+ aliases := make([]string, 0, 8)
+ addAlias := func(value string) {
+ v := strings.ToLower(strings.TrimSpace(value))
+ if v == "" {
+ return
+ }
+ for _, existing := range aliases {
+ if existing == v {
+ return
+ }
+ }
+ aliases = append(aliases, v)
+ }
+
+ addAlias(modelID)
+ cfg, ok := GetSoraModelConfig(modelID)
+ if ok {
+ addAlias(cfg.Model)
+ switch cfg.Type {
+ case "video":
+ addAlias("video")
+ addAlias("sora")
+ addAlias(soraVideoFamilyAlias(modelID))
+ case "image":
+ addAlias("image")
+ addAlias("gpt-image")
+ case "prompt_enhance":
+ addAlias("prompt-enhance")
+ }
+ return aliases
+ }
+
+ switch {
+ case strings.HasPrefix(modelID, "sora"):
+ addAlias("video")
+ addAlias("sora")
+ addAlias(soraVideoFamilyAlias(modelID))
+ case strings.HasPrefix(modelID, "gpt-image"):
+ addAlias("image")
+ addAlias("gpt-image")
+ case strings.HasPrefix(modelID, "prompt-enhance"):
+ addAlias("prompt-enhance")
+ default:
+ return nil
+ }
+
+ return aliases
+}
+
+func soraVideoFamilyAlias(modelID string) string {
+ switch {
+ case strings.HasPrefix(modelID, "sora2pro-hd"):
+ return "sora2pro-hd"
+ case strings.HasPrefix(modelID, "sora2pro"):
+ return "sora2pro"
+ case strings.HasPrefix(modelID, "sora2"):
+ return "sora2"
+ default:
+ return ""
+ }
+}
+
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
@@ -3719,86 +4152,6 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
return result
}
-// rewriteSystemForNonClaudeCode 将非 Claude Code 客户端的 system prompt 迁移至 messages,
-// system 字段仅保留 Claude Code 标识提示词。
-// Anthropic 基于 system 参数内容检测第三方应用,仅前置追加 Claude Code 提示词
-// 无法通过检测,因为后续内容仍为非 Claude Code 格式。
-// 策略:将原始 system prompt 提取并注入为 user/assistant 消息对,system 仅保留 Claude Code 标识。
-func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
- system = normalizeSystemParam(system)
-
- // 1. 提取原始 system prompt 文本
- var originalSystemText string
- switch v := system.(type) {
- case string:
- originalSystemText = strings.TrimSpace(v)
- case []any:
- var parts []string
- for _, item := range v {
- if m, ok := item.(map[string]any); ok {
- if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" {
- parts = append(parts, text)
- }
- }
- }
- originalSystemText = strings.Join(parts, "\n\n")
- }
-
- // 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致)
- // 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。
- // 使用 string 格式会被 Anthropic 检测为第三方应用。
- claudeCodeSystemBlock := []map[string]any{
- {
- "type": "text",
- "text": claudeCodeSystemPrompt,
- "cache_control": map[string]string{"type": "ephemeral"},
- },
- }
- out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock)
- if !ok {
- logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt")
- return body
- }
-
- // 3. 将原始 system prompt 作为 user/assistant 消息对注入到 messages 开头
- // 模型仍通过 messages 接收完整指令,保留客户端功能
- ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt)
- if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) {
- instrMsg, err1 := json.Marshal(map[string]any{
- "role": "user",
- "content": []map[string]any{
- {"type": "text", "text": "[System Instructions]\n" + originalSystemText},
- },
- })
- ackMsg, err2 := json.Marshal(map[string]any{
- "role": "assistant",
- "content": []map[string]any{
- {"type": "text", "text": "Understood. I will follow these instructions."},
- },
- })
- if err1 != nil || err2 != nil {
- logger.LegacyPrintf("service.gateway", "Warning: failed to marshal system-to-messages injection")
- return out
- }
-
- // 重建 messages 数组:[instruction, ack, ...originalMessages]
- items := [][]byte{instrMsg, ackMsg}
- messagesResult := gjson.GetBytes(out, "messages")
- if messagesResult.IsArray() {
- messagesResult.ForEach(func(_, msg gjson.Result) bool {
- items = append(items, []byte(msg.Raw))
- return true
- })
- }
-
- if next, setOk := setJSONRawBytes(out, "messages", buildJSONArrayRaw(items)); setOk {
- out = next
- }
- }
-
- return out
-}
-
type cacheControlPath struct {
path string
log string
@@ -3960,7 +4313,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
if account.Platform == PlatformAnthropic && c != nil {
- policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account, parsed.Model)
+ policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account)
if policy.blockErr != nil {
return nil, policy.blockErr
}
@@ -3990,24 +4343,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
- // 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages
+ // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
- systemRewritten := false
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) {
- body = rewriteSystemForNonClaudeCode(body, parsed.System)
- systemRewritten = true
+ body = injectClaudeCodePrompt(body, parsed.System)
}
- // system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为);
- // 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。
- // 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
- normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
+ normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
if s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil && fp != nil {
// metadata 透传开启时跳过 metadata 注入
- _, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx)
+ _, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx)
if !mimicMPT {
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
normalizeOpts.injectMetadata = true
@@ -4054,12 +4402,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, err
}
- // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
+ // 获取代理URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
- if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
- proxyURL = account.Proxy.URL()
- }
+ proxyURL = account.Proxy.URL()
}
// 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析)
@@ -4468,6 +4814,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 处理正常响应
+ ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed)
// 触发上游接受回调(提前释放串行锁,不等流完成)
if parsed.OnUpstreamAccepted != nil {
@@ -5534,16 +5881,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
targetURL = validatedURL + "/v1/messages?beta=true"
}
- } else if account.IsCustomBaseURLEnabled() {
- customURL := account.GetCustomBaseURL()
- if customURL == "" {
- return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
- }
- validatedURL, err := s.validateUpstreamBaseURL(customURL)
- if err != nil {
- return nil, err
- }
- targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account)
}
clientHeaders := http.Header{}
@@ -5553,9 +5890,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// OAuth账号:应用统一指纹和metadata重写(受设置开关控制)
var fingerprint *Fingerprint
- enableFP, enableMPT, enableCCH := true, false, false
+ enableFP, enableMPT := true, false
if s.settingService != nil {
- enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
+ enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
}
if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹(包含随机生成的ClientID)
@@ -5582,15 +5919,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
- // 同步 billing header cc_version 与实际发送的 User-Agent 版本
- if fingerprint != nil {
- body = syncBillingHeaderVersion(body, fingerprint.UserAgent)
- }
- // CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后)
- if enableCCH {
- body = signBillingHeaderCCH(body)
- }
-
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -5631,8 +5959,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
- policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID)
+ policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account)
effectiveDropSet := mergeDropSets(policyFilterSet)
+ effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode)
// 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta)
if tokenType == "oauth" {
@@ -5643,16 +5972,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
applyClaudeCodeMimicHeaders(req, reqStream)
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
- // Claude Code OAuth credentials are scoped to Claude Code.
- // Non-haiku models MUST include claude-code beta for Anthropic to recognize
- // this as a legitimate Claude Code request; without it, the request is
- // rejected as third-party ("out of extra usage").
- // Haiku models are exempt from third-party detection and don't need it.
+ // Match real Claude CLI traffic (per mitmproxy reports):
+ // messages requests typically use only oauth + interleaved-thinking.
+ // Also drop claude-code beta if a downstream client added it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
- if !strings.Contains(strings.ToLower(modelID), "haiku") {
- requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking}
- }
- setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
+ setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet))
} else {
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
@@ -5672,15 +5996,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
- // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
- if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
- if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
- if parsed := ParseMetadataUserID(uid); parsed != nil {
- setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
- }
- }
- }
-
// === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{
"url": req.URL.String(),
@@ -5875,7 +6190,7 @@ type betaPolicyResult struct {
}
// evaluateBetaPolicy loads settings once and evaluates all rules against the given request.
-func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account, model string) betaPolicyResult {
+func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult {
if s.settingService == nil {
return betaPolicyResult{}
}
@@ -5890,11 +6205,10 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue
}
- effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
- switch effectiveAction {
+ switch rule.Action {
case BetaPolicyActionBlock:
if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) {
- msg := effectiveErrMsg
+ msg := rule.ErrorMessage
if msg == "" {
msg = "beta feature " + rule.BetaToken + " is not allowed"
}
@@ -5936,7 +6250,7 @@ const betaPolicyFilterSetKey = "betaPolicyFilterSet"
// In the /v1/messages path, Forward() evaluates the policy first and caches the result;
// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this
// evaluates on demand (one DB call).
-func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account, model string) map[string]struct{} {
+func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} {
if c != nil {
if v, ok := c.Get(betaPolicyFilterSetKey); ok {
if fs, ok := v.(map[string]struct{}); ok {
@@ -5944,7 +6258,7 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont
}
}
}
- return s.evaluateBetaPolicy(ctx, "", account, model).filterSet
+ return s.evaluateBetaPolicy(ctx, "", account).filterSet
}
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
@@ -5963,33 +6277,6 @@ func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
}
}
-// matchModelWhitelist checks if a model matches any pattern in the whitelist.
-// Reuses matchModelPattern from group.go which supports exact and wildcard prefix matching.
-func matchModelWhitelist(model string, whitelist []string) bool {
- for _, pattern := range whitelist {
- if matchModelPattern(pattern, model) {
- return true
- }
- }
- return false
-}
-
-// resolveRuleAction determines the effective action and error message for a rule given the request model.
-// When ModelWhitelist is empty, the rule's primary Action/ErrorMessage applies unconditionally.
-// When non-empty, Action applies to matching models; FallbackAction/FallbackErrorMessage applies to others.
-func resolveRuleAction(rule BetaPolicyRule, model string) (action, errorMessage string) {
- if len(rule.ModelWhitelist) == 0 {
- return rule.Action, rule.ErrorMessage
- }
- if matchModelWhitelist(model, rule.ModelWhitelist) {
- return rule.Action, rule.ErrorMessage
- }
- if rule.FallbackAction != "" {
- return rule.FallbackAction, rule.FallbackErrorMessage
- }
- return BetaPolicyActionPass, "" // default fallback: pass (fail-open)
-}
-
// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
func droppedBetaSet(extra ...string) map[string]struct{} {
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra))
@@ -6036,7 +6323,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
modelID string,
) ([]string, error) {
// 1. 对原始 header 中的 beta token 做 block 检查(快速失败)
- policy := s.evaluateBetaPolicy(ctx, betaHeader, account, modelID)
+ policy := s.evaluateBetaPolicy(ctx, betaHeader, account)
if policy.blockErr != nil {
return nil, policy.blockErr
}
@@ -6048,7 +6335,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
// 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token,
// 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 →
// 如果不做此检查,block 规则会被绕过。
- if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account, modelID); blockErr != nil {
+ if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil {
return nil, blockErr
}
@@ -6057,7 +6344,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
// checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。
// 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。
-func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account, model string) *BetaBlockedError {
+func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError {
if s.settingService == nil || len(tokens) == 0 {
return nil
}
@@ -6069,15 +6356,14 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke
isBedrock := account.IsBedrock()
tokenSet := buildBetaTokenSet(tokens)
for _, rule := range settings.Rules {
- effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
- if effectiveAction != BetaPolicyActionBlock {
+ if rule.Action != BetaPolicyActionBlock {
continue
}
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue
}
if _, present := tokenSet[rule.BetaToken]; present {
- msg := effectiveErrMsg
+ msg := rule.ErrorMessage
if msg == "" {
msg = "beta feature " + rule.BetaToken + " is not allowed"
}
@@ -6709,6 +6995,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
sawTerminalEvent := false
+ skipAccountTTLOverride := false
pendingEventLines := make([]string, 0, 4)
@@ -6770,17 +7057,25 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
eventChanged = reconcileCachedTokens(u) || eventChanged
+ claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
+ if claudeMaxOutcome.Simulated {
+ skipAccountTTLOverride = true
+ }
}
}
}
if eventType == "message_delta" {
if u, ok := event["usage"].(map[string]any); ok {
eventChanged = reconcileCachedTokens(u) || eventChanged
+ claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
+ if claudeMaxOutcome.Simulated {
+ skipAccountTTLOverride = true
+ }
}
}
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
- if account.IsCacheTTLOverrideEnabled() {
+ if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride {
overrideTarget := account.GetCacheTTLOverrideTarget()
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
@@ -7212,8 +7507,13 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
}
+ claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID)
+ if claudeMaxOutcome.Simulated {
+ body = rewriteClaudeUsageJSONBytes(body, response.Usage)
+ }
+
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
- if account.IsCacheTTLOverrideEnabled() {
+ if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated {
overrideTarget := account.GetCacheTTLOverrideTarget()
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
// 同步更新 body JSON 中的嵌套 cache_creation 对象
@@ -7279,6 +7579,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
+ ParsedRequest *ParsedRequest
APIKey *APIKey
User *User
Account *Account
@@ -7437,6 +7738,9 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
cmd.CacheReadTokens = usageLog.CacheReadTokens
cmd.ImageCount = usageLog.ImageCount
+ if usageLog.MediaType != nil {
+ cmd.MediaType = *usageLog.MediaType
+ }
if usageLog.ServiceTier != nil {
cmd.ServiceTier = *usageLog.ServiceTier
}
@@ -7592,6 +7896,8 @@ type recordUsageOpts struct {
// EnableClaudePath 启用 Claude 路径特有逻辑:
// - Claude Max 缓存计费策略
+ // - Sora 媒体类型分支(image/video/prompt)
+ // - MediaType 字段写入使用日志
EnableClaudePath bool
// 长上下文计费(仅 Gemini 路径需要)
@@ -7616,6 +7922,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
APIKeyService: input.APIKeyService,
ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{
+ ParsedRequest: input.ParsedRequest,
EnableClaudePath: true,
})
}
@@ -7682,6 +7989,7 @@ type recordUsageCoreInput struct {
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为:
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
+// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt)
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
result := input.Result
@@ -7699,9 +8007,21 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
result.Usage.InputTokens = 0
}
- // Cache TTL Override: 确保计费时 token 分类与账号设置一致
+ // Claude Max cache billing policy(仅 Claude 路径启用)
cacheTTLOverridden := false
- if account.IsCacheTTLOverrideEnabled() {
+ simulatedClaudeMax := false
+ if opts.EnableClaudePath {
+ var apiKeyGroup *Group
+ if apiKey != nil {
+ apiKeyGroup = apiKey.Group
+ }
+ claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, opts.ParsedRequest, apiKeyGroup, result.Model, account.ID)
+ simulatedClaudeMax = claudeMaxOutcome.Simulated ||
+ (shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, opts.ParsedRequest) && hasCacheCreationTokens(result.Usage))
+ }
+
+ // Cache TTL Override: 确保计费时 token 分类与账号设置一致
+ if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}
@@ -7783,6 +8103,16 @@ func (s *GatewayService) calculateRecordUsageCost(
multiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
+ // Sora 媒体类型分支(仅 Claude 路径启用)
+ if opts.EnableClaudePath {
+ if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
+ return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
+ }
+ if result.MediaType == MediaTypePrompt {
+ return &CostBreakdown{}
+ }
+ }
+
// 图片生成计费
if result.ImageCount > 0 {
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
@@ -7792,6 +8122,28 @@ func (s *GatewayService) calculateRecordUsageCost(
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
}
+// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
+func (s *GatewayService) calculateSoraMediaCost(
+ result *ForwardResult,
+ apiKey *APIKey,
+ billingModel string,
+ multiplier float64,
+) *CostBreakdown {
+ var soraConfig *SoraPriceConfig
+ if apiKey.Group != nil {
+ soraConfig = &SoraPriceConfig{
+ ImagePrice360: apiKey.Group.SoraImagePrice360,
+ ImagePrice540: apiKey.Group.SoraImagePrice540,
+ VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
+ VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
+ }
+ }
+ if result.MediaType == MediaTypeImage {
+ return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
+ }
+ return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
+}
+
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
@@ -7814,7 +8166,7 @@ func (s *GatewayService) calculateImageCost(
billingModel string,
multiplier float64,
) *CostBreakdown {
- if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
+ if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
@@ -7829,7 +8181,6 @@ func (s *GatewayService) calculateImageCost(
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
- Resolved: resolved,
})
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err)
@@ -7872,7 +8223,7 @@ func (s *GatewayService) calculateTokenCost(
var err error
// 优先尝试渠道定价 → CalculateCostUnified
- if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
+ if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
@@ -7882,7 +8233,6 @@ func (s *GatewayService) calculateTokenCost(
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
- Resolved: resolved,
})
} else if opts.LongContextThreshold > 0 {
// 长上下文双倍计费(如 Gemini 200K 阈值)
@@ -7940,12 +8290,13 @@ func (s *GatewayService) buildRecordUsageLog(
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
- BillingMode: resolveBillingMode(result, cost),
+ BillingMode: resolveBillingMode(opts, result, cost),
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
+ MediaType: resolveMediaType(opts, result),
CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
@@ -7969,7 +8320,13 @@ func (s *GatewayService) buildRecordUsageLog(
}
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
-func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
+// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
+func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
+ isSoraMedia := opts.EnableClaudePath &&
+ (result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
+ if isSoraMedia {
+ return nil
+ }
var mode string
switch {
case cost != nil && cost.BillingMode != "":
@@ -7982,6 +8339,13 @@ func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
return &mode
}
+func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
+ if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
+ return &result.MediaType
+ }
+ return nil
+}
+
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
if subscription != nil {
return &subscription.ID
@@ -8010,8 +8374,8 @@ func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, m
return s.channelService.IsModelRestricted(ctx, groupID, model)
}
-// ResolveChannelMappingAndRestrict 解析渠道映射。
-// 模型限制检查已移至调度阶段(checkChannelPricingRestriction),restricted 始终返回 false。
+// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。
+// 返回映射结果和是否被限制。
func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}, false
@@ -8042,9 +8406,7 @@ func billingModelForRestriction(source, requestedModel, channelMappedModel strin
return requestedModel
case BillingModelSourceUpstream:
return ""
- case BillingModelSourceChannelMapped:
- return channelMappedModel
- default:
+ default: // channel_mapped
return channelMappedModel
}
}
@@ -8076,11 +8438,7 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
- if err != nil {
- slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err)
- return false
- }
- if ch == nil || !ch.RestrictModels {
+ if err != nil || ch == nil || !ch.RestrictModels {
return false
}
return ch.BillingModelSource == BillingModelSourceUpstream
@@ -8172,12 +8530,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return err
}
- // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
+ // 获取代理URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
- if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
- proxyURL = account.Proxy.URL()
- }
+ proxyURL = account.Proxy.URL()
}
// 发送请求
@@ -8456,16 +8812,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
}
- } else if account.IsCustomBaseURLEnabled() {
- customURL := account.GetCustomBaseURL()
- if customURL == "" {
- return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
- }
- validatedURL, err := s.validateUpstreamBaseURL(customURL)
- if err != nil {
- return nil, err
- }
- targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account)
}
clientHeaders := http.Header{}
@@ -8475,9 +8821,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用统一指纹和重写 userID(受设置开关控制)
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
- ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false
+ ctEnableFP, ctEnableMPT := true, false
if s.settingService != nil {
- ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
+ ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
}
var ctFingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil {
@@ -8495,14 +8841,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
- // 同步 billing header cc_version 与实际发送的 User-Agent 版本
- if ctFingerprint != nil && ctEnableFP {
- body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent)
- }
- if ctEnableCCH {
- body = signBillingHeaderCCH(body)
- }
-
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -8543,7 +8881,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
- ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID))
+ ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account))
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
@@ -8579,15 +8917,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
- // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
- if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
- if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
- if parsed := ParseMetadataUserID(uid); parsed != nil {
- setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
- }
- }
- }
-
if c != nil && tokenType == "oauth" {
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
}
@@ -8609,19 +8938,6 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
})
}
-// buildCustomRelayURL 构建自定义中继转发 URL
-// 在 path 后附加 beta=true 和可选的 proxy 查询参数
-func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string {
- u := strings.TrimRight(baseURL, "/") + path + "?beta=true"
- if account.ProxyID != nil && account.Proxy != nil {
- proxyURL := account.Proxy.URL()
- if proxyURL != "" {
- u += "&proxy=" + url.QueryEscape(proxyURL)
- }
- }
- return u
-}
-
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
From 160903fce7841b99d761b30372f793bcbf448c5f Mon Sep 17 00:00:00 2001
From: erio
Date: Thu, 2 Apr 2026 13:36:58 +0800
Subject: [PATCH 04/88] fix: address review findings for channel restriction
refactoring
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Fix 7 stale comments still mentioning "限制检查" in handlers/services
- Make billingModelForRestriction explicitly list channel_mapped case
- Add slog.Warn for error swallowing in ResolveChannelMapping and
needsUpstreamChannelRestrictionCheck
- Document sticky session upstream check exemption
---
.../handler/gateway_handler_chat_completions.go | 2 +-
.../handler/gateway_handler_responses.go | 2 +-
.../internal/handler/gemini_v1beta_handler.go | 2 +-
.../internal/handler/openai_chat_completions.go | 2 +-
.../internal/handler/openai_gateway_handler.go | 2 +-
backend/internal/service/gateway_service.go | 17 +++++++++++++----
6 files changed, 18 insertions(+), 9 deletions(-)
diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go
index abe2a1e5..be267332 100644
--- a/backend/internal/handler/gateway_handler_chat_completions.go
+++ b/backend/internal/handler/gateway_handler_chat_completions.go
@@ -80,7 +80,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
- // 解析渠道级模型映射 + 限制检查
+ // 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction
diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go
index cf877182..e908eb9e 100644
--- a/backend/internal/handler/gateway_handler_responses.go
+++ b/backend/internal/handler/gateway_handler_responses.go
@@ -80,7 +80,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
- // 解析渠道级模型映射 + 限制检查
+ // 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction:
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index ff63bc7f..45b5842f 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -184,7 +184,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
- // 解析渠道级模型映射 + 限制检查
+ // 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
reqModel := modelName // 保存映射前的原始模型名
if channelMapping.Mapped {
diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go
index ada401c9..991cbb91 100644
--- a/backend/internal/handler/openai_chat_completions.go
+++ b/backend/internal/handler/openai_chat_completions.go
@@ -79,7 +79,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
- // 解析渠道级模型映射 + 限制检查
+ // 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if h.errorPassthroughService != nil {
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 2b081617..dda6d2e3 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -1118,7 +1118,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
- // 解析渠道级模型映射 + 限制检查
+ // 解析渠道级模型映射
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
var currentUserRelease func()
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 33ab38f2..24f36113 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -3143,6 +3143,8 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持)
+ // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查,
+ // 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
@@ -3381,6 +3383,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
+ // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
@@ -8374,8 +8377,8 @@ func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, m
return s.channelService.IsModelRestricted(ctx, groupID, model)
}
-// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。
-// 返回映射结果和是否被限制。
+// ResolveChannelMappingAndRestrict 解析渠道映射。
+// 模型限制检查已移至调度阶段(checkChannelPricingRestriction),restricted 始终返回 false。
func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}, false
@@ -8406,7 +8409,9 @@ func billingModelForRestriction(source, requestedModel, channelMappedModel strin
return requestedModel
case BillingModelSourceUpstream:
return ""
- default: // channel_mapped
+ case BillingModelSourceChannelMapped:
+ return channelMappedModel
+ default:
return channelMappedModel
}
}
@@ -8438,7 +8443,11 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
- if err != nil || ch == nil || !ch.RestrictModels {
+ if err != nil {
+ slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err)
+ return false
+ }
+ if ch == nil || !ch.RestrictModels {
return false
}
return ch.BillingModelSource == BillingModelSourceUpstream
From e3748741257c2c6b96ced2c0c0284dd31cc5e358 Mon Sep 17 00:00:00 2001
From: erio
Date: Fri, 3 Apr 2026 13:54:18 +0800
Subject: [PATCH 05/88] feat(channel): improve cache strategy and add
restriction logging
- Change channel cache TTL from 60s to 10min (reduce unnecessary DB queries)
- Actively rebuild cache after CRUD instead of lazy invalidation
- Add slog.Warn logging for channel pricing restriction blocks (4 places)
---
backend/internal/service/channel_service.go | 378 ++++++++------------
backend/internal/service/gateway_service.go | 46 ++-
2 files changed, 183 insertions(+), 241 deletions(-)
diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go
index 9667cb98..c6a249ef 100644
--- a/backend/internal/service/channel_service.go
+++ b/backend/internal/service/channel_service.go
@@ -134,7 +134,7 @@ func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) Chan
const (
channelCacheTTL = 10 * time.Minute
- channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
+ channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
channelCacheDBTimeout = 10 * time.Second
)
@@ -197,8 +197,10 @@ func newEmptyChannelCache() *channelCache {
}
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
-// 各平台严格独立:antigravity 分组只匹配 antigravity 定价,不会匹配 anthropic/gemini 的定价。
-// 查找时通过 lookupPricingAcrossPlatforms() 在本平台内查找。
+// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
+// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台,
+// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
+// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j]
@@ -224,7 +226,8 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
}
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
-// 各平台严格独立:antigravity 分组只匹配 antigravity 映射。
+// antigravity 平台同时服务 Claude 和 Gemini 模型。
+// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for _, mappingPlatform := range matchingPlatforms(platform) {
platformMapping, ok := ch.ModelMapping[mappingPlatform]
@@ -248,58 +251,40 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
}
}
-// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。
-// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。
-func (s *ChannelService) storeErrorCache() {
- errorCache := newEmptyChannelCache()
- errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
- s.cache.Store(errorCache)
-}
-
// buildCache 从数据库构建渠道缓存。
// 使用独立 context 避免请求取消导致空值被长期缓存。
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
+ // 断开请求取消链,避免客户端断连导致空值被长期缓存
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
defer cancel()
- channels, groupPlatforms, err := s.fetchChannelData(dbCtx)
- if err != nil {
- return nil, err
- }
-
- cache := populateChannelCache(channels, groupPlatforms)
- s.cache.Store(cache)
- return cache, nil
-}
-
-// fetchChannelData 从数据库加载渠道列表和分组平台映射。
-func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) {
- channels, err := s.repo.ListAll(ctx)
+ channels, err := s.repo.ListAll(dbCtx)
if err != nil {
+ // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err)
- s.storeErrorCache()
- return nil, nil, fmt.Errorf("list all channels: %w", err)
+ errorCache := newEmptyChannelCache()
+ errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
+ s.cache.Store(errorCache)
+ return nil, fmt.Errorf("list all channels: %w", err)
}
+ // 收集所有 groupID,批量查询 platform
var allGroupIDs []int64
for i := range channels {
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
}
-
groupPlatforms := make(map[int64]string)
if len(allGroupIDs) > 0 {
- groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs)
+ groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
if err != nil {
slog.Warn("failed to load group platforms for channel cache", "error", err)
- s.storeErrorCache()
- return nil, nil, fmt.Errorf("get group platforms: %w", err)
+ errorCache := newEmptyChannelCache()
+ errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
+ s.cache.Store(errorCache)
+ return nil, fmt.Errorf("get group platforms: %w", err)
}
}
- return channels, groupPlatforms, nil
-}
-// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
-func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
cache := newEmptyChannelCache()
cache.groupPlatform = groupPlatforms
cache.byID = make(map[int64]*Channel, len(channels))
@@ -308,6 +293,7 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *
for i := range channels {
ch := &channels[i]
cache.byID[ch.ID] = ch
+
for _, gid := range ch.GroupIDs {
cache.channelByGroupID[gid] = ch
platform := groupPlatforms[gid]
@@ -315,20 +301,33 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *
expandMappingToCache(cache, ch, gid, platform)
}
}
- return cache
+
+ // 通配符条目保持配置顺序(最先匹配到优先)
+
+ s.cache.Store(cache)
+ return cache, nil
}
// invalidateCache 使缓存失效,让下次读取时自然重建
// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。
-// 各平台(antigravity / anthropic / gemini / openai)严格独立,不跨平台匹配。
+// antigravity 平台同时服务 Claude(anthropic)和 Gemini(gemini)模型,
+// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。
func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool {
- return groupPlatform == pricingPlatform
+ if groupPlatform == pricingPlatform {
+ return true
+ }
+ if groupPlatform == PlatformAntigravity {
+ return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini
+ }
+ return false
}
-// matchingPlatforms 返回分组平台对应的可匹配平台列表。
-// 各平台严格独立,只返回自身。
+// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。
func matchingPlatforms(groupPlatform string) []string {
+ if groupPlatform == PlatformAntigravity {
+ return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}
+ }
return []string{groupPlatform}
}
func (s *ChannelService) invalidateCache() {
@@ -365,8 +364,10 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower
return ""
}
-// lookupPricingAcrossPlatforms 在分组平台内查找模型定价。
-// 各平台严格独立,只在本平台内查找(先精确匹配,再通配符)。
+// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
+// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
+// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini),
+// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
for _, p := range matchingPlatforms(groupPlatform) {
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
@@ -383,7 +384,7 @@ func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatf
return nil
}
-// lookupMappingAcrossPlatforms 在分组平台内查找模型映射。
+// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
for _, p := range matchingPlatforms(groupPlatform) {
@@ -441,7 +442,8 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64)
}
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
-// 各平台严格独立,只在本平台内查找定价。
+// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini),
+// 确保跨平台同名模型各自独立匹配。
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
@@ -479,10 +481,7 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
// 返回 true 表示模型被限制(不在允许列表中)。
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
- lk, err := s.lookupGroupChannel(ctx, groupID)
- if err != nil {
- slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err)
- }
+ lk, _ := s.lookupGroupChannel(ctx, groupID)
if lk == nil {
return false
}
@@ -525,7 +524,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
}
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
-// 只在本平台的定价列表中查找。
+// antigravity 分组依次尝试所有匹配平台的定价列表。
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
if !lk.channel.RestrictModels {
return false
@@ -553,91 +552,6 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
return newBody
}
-// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
-// Create 和 Update 共用此函数,避免重复。
-func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
- if err := validateNoConflictingModels(pricing); err != nil {
- return err
- }
- if err := validatePricingIntervals(pricing); err != nil {
- return err
- }
- if err := validateNoConflictingMappings(mapping); err != nil {
- return err
- }
- return validatePricingBillingMode(pricing)
-}
-
-// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。
-func validatePricingBillingMode(pricing []ChannelModelPricing) error {
- for _, p := range pricing {
- if err := checkBillingModeRequirements(p); err != nil {
- return err
- }
- if err := checkPricesNotNegative(p); err != nil {
- return err
- }
- if err := checkIntervalsHavePrices(p); err != nil {
- return err
- }
- }
- return nil
-}
-
-func checkBillingModeRequirements(p ChannelModelPricing) error {
- if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage {
- if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
- return infraerrors.BadRequest(
- "BILLING_MODE_MISSING_PRICE",
- "per-request price or intervals required for per_request/image billing mode",
- )
- }
- }
- return nil
-}
-
-func checkPricesNotNegative(p ChannelModelPricing) error {
- checks := []struct {
- field string
- val *float64
- }{
- {"input_price", p.InputPrice},
- {"output_price", p.OutputPrice},
- {"cache_write_price", p.CacheWritePrice},
- {"cache_read_price", p.CacheReadPrice},
- {"image_output_price", p.ImageOutputPrice},
- {"per_request_price", p.PerRequestPrice},
- }
- for _, c := range checks {
- if c.val != nil && *c.val < 0 {
- return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field))
- }
- }
- return nil
-}
-
-func checkIntervalsHavePrices(p ChannelModelPricing) error {
- for _, iv := range p.Intervals {
- if iv.InputPrice == nil && iv.OutputPrice == nil &&
- iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
- iv.PerRequestPrice == nil {
- return infraerrors.BadRequest(
- "INTERVAL_MISSING_PRICE",
- fmt.Sprintf("interval [%d, %s] has no price fields set for model %v",
- iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models),
- )
- }
- }
- return nil
-}
-
-func formatMaxTokens(max *int) string {
- if max == nil {
- return "∞"
- }
- return fmt.Sprintf("%d", *max)
-}
-
// --- CRUD ---
// Create 创建渠道
@@ -650,8 +564,15 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
return nil, ErrChannelExists
}
- if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil {
- return nil, err
+ // 检查分组冲突
+ if len(input.GroupIDs) > 0 {
+ conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
+ if err != nil {
+ return nil, fmt.Errorf("check group conflicts: %w", err)
+ }
+ if len(conflicting) > 0 {
+ return nil, ErrGroupAlreadyInChannel
+ }
}
channel := &Channel{
@@ -668,7 +589,13 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
channel.BillingModelSource = BillingModelSourceChannelMapped
}
- if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
+ if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
+ return nil, err
+ }
+ if err := validatePricingIntervals(channel.ModelPricing); err != nil {
+ return nil, err
+ }
+ if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err
}
@@ -692,112 +619,102 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
return nil, fmt.Errorf("get channel: %w", err)
}
- if err := s.applyUpdateInput(ctx, channel, input); err != nil {
+ if input.Name != "" && input.Name != channel.Name {
+ exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
+ if err != nil {
+ return nil, fmt.Errorf("check channel exists: %w", err)
+ }
+ if exists {
+ return nil, ErrChannelExists
+ }
+ channel.Name = input.Name
+ }
+
+ if input.Description != nil {
+ channel.Description = *input.Description
+ }
+
+ if input.Status != "" {
+ channel.Status = input.Status
+ }
+
+ if input.RestrictModels != nil {
+ channel.RestrictModels = *input.RestrictModels
+ }
+
+ // 检查分组冲突
+ if input.GroupIDs != nil {
+ conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
+ if err != nil {
+ return nil, fmt.Errorf("check group conflicts: %w", err)
+ }
+ if len(conflicting) > 0 {
+ return nil, ErrGroupAlreadyInChannel
+ }
+ channel.GroupIDs = *input.GroupIDs
+ }
+
+ if input.ModelPricing != nil {
+ channel.ModelPricing = *input.ModelPricing
+ }
+
+ if input.ModelMapping != nil {
+ channel.ModelMapping = input.ModelMapping
+ }
+
+ if input.BillingModelSource != "" {
+ channel.BillingModelSource = input.BillingModelSource
+ }
+
+ if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
+ return nil, err
+ }
+ if err := validatePricingIntervals(channel.ModelPricing); err != nil {
+ return nil, err
+ }
+ if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err
}
- if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
- return nil, err
+ // 先获取旧分组,Update 后旧分组关联已删除,无法再查到
+ var oldGroupIDs []int64
+ if s.authCacheInvalidator != nil {
+ var err2 error
+ oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
+ if err2 != nil {
+ slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
+ }
}
- oldGroupIDs := s.getOldGroupIDs(ctx, id)
-
if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err)
}
s.invalidateCache()
- s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
+
+ // 失效新旧分组的 auth 缓存
+ if s.authCacheInvalidator != nil {
+ seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
+ for _, gid := range oldGroupIDs {
+ if _, ok := seen[gid]; !ok {
+ seen[gid] = struct{}{}
+ s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
+ }
+ }
+ for _, gid := range channel.GroupIDs {
+ if _, ok := seen[gid]; !ok {
+ seen[gid] = struct{}{}
+ s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
+ }
+ }
+ }
return s.repo.GetByID(ctx, id)
}
-// applyUpdateInput 将更新请求的字段应用到渠道实体上。
-func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error {
- if input.Name != "" && input.Name != channel.Name {
- exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID)
- if err != nil {
- return fmt.Errorf("check channel exists: %w", err)
- }
- if exists {
- return ErrChannelExists
- }
- channel.Name = input.Name
- }
- if input.Description != nil {
- channel.Description = *input.Description
- }
- if input.Status != "" {
- channel.Status = input.Status
- }
- if input.RestrictModels != nil {
- channel.RestrictModels = *input.RestrictModels
- }
- if input.GroupIDs != nil {
- if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
- return err
- }
- channel.GroupIDs = *input.GroupIDs
- }
- if input.ModelPricing != nil {
- channel.ModelPricing = *input.ModelPricing
- }
- if input.ModelMapping != nil {
- channel.ModelMapping = input.ModelMapping
- }
- if input.BillingModelSource != "" {
- channel.BillingModelSource = input.BillingModelSource
- }
- return nil
-}
-
-// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。
-// channelID 为当前渠道 ID(Create 时传 0)。
-func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error {
- if len(groupIDs) == 0 {
- return nil
- }
- conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs)
- if err != nil {
- return fmt.Errorf("check group conflicts: %w", err)
- }
- if len(conflicting) > 0 {
- return ErrGroupAlreadyInChannel
- }
- return nil
-}
-
-// getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。
-func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 {
- if s.authCacheInvalidator == nil {
- return nil
- }
- oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID)
- if err != nil {
- slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err)
- }
- return oldGroupIDs
-}
-
-// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。
-func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) {
- if s.authCacheInvalidator == nil {
- return
- }
- seen := make(map[int64]struct{})
- for _, ids := range groupIDSets {
- for _, gid := range ids {
- if _, ok := seen[gid]; ok {
- continue
- }
- seen[gid] = struct{}{}
- s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
- }
- }
-}
-
// Delete 删除渠道
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
+ // 先获取关联分组用于失效缓存
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
if err != nil {
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
@@ -808,7 +725,12 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
}
s.invalidateCache()
- s.invalidateAuthCacheForGroups(ctx, groupIDs)
+
+ if s.authCacheInvalidator != nil {
+ for _, gid := range groupIDs {
+ s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
+ }
+ }
return nil
}
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 24f36113..31137fb4 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -1234,11 +1234,6 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
- // 渠道定价限制预检查(requested / channel_mapped 基准)
- if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
- return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
- }
-
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
@@ -1257,6 +1252,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
platform = PlatformAnthropic
}
+ // Claude Code 限制可能已将 groupID 解析为 fallback group,
+ // 渠道限制预检查必须使用解析后的分组。
+ if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
+ slog.Warn("channel pricing restriction blocked request",
+ "group_id", derefGroupID(groupID),
+ "model", requestedModel)
+ return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
+ }
+
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
@@ -1273,11 +1277,6 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
// sub2apiUserID: 系统用户 ID,用于二维亲和调度
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
- // 渠道定价限制预检查(requested / channel_mapped 基准)
- if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
- return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
- }
-
// 调试日志:记录调度入口参数
excludedIDsList := make([]int64, 0, len(excludedIDs))
for id := range excludedIDs {
@@ -1298,6 +1297,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
ctx = s.withGroupContext(ctx, group)
+ // Claude Code 限制可能已将 groupID 解析为 fallback group,
+ // 渠道限制预检查必须使用解析后的分组。
+ if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
+ slog.Warn("channel pricing restriction blocked request",
+ "group_id", derefGroupID(groupID),
+ "model", requestedModel)
+ return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
+ }
+
var stickyAccountID int64
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
stickyAccountID = prefetch
@@ -3004,7 +3012,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
@@ -3359,7 +3367,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
return account, nil
}
@@ -3383,7 +3391,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
- // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
@@ -8453,6 +8460,19 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
return ch.BillingModelSource == BillingModelSourceUpstream
}
+// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。
+// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用,
+// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。
+func (s *GatewayService) isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool {
+ if groupID == nil {
+ return false
+ }
+ if !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) {
+ return false
+ }
+ return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel)
+}
+
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
From 37c23eccfed36d59e70369e37d26b1eff5ac4a0f Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 5 Apr 2026 14:37:21 +0800
Subject: [PATCH 06/88] fix: gofmt formatting
---
backend/internal/service/channel_service.go | 2 +-
backend/internal/service/gateway_service.go | 689 +++-----------------
2 files changed, 82 insertions(+), 609 deletions(-)
diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go
index c6a249ef..ec8310f6 100644
--- a/backend/internal/service/channel_service.go
+++ b/backend/internal/service/channel_service.go
@@ -134,7 +134,7 @@ func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) Chan
const (
channelCacheTTL = 10 * time.Minute
- channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
+ channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
channelCacheDBTimeout = 10 * time.Second
)
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 31137fb4..5d285fb6 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -12,6 +12,7 @@ import (
"log/slog"
mathrand "math/rand"
"net/http"
+ "net/url"
"os"
"path/filepath"
"regexp"
@@ -41,8 +42,7 @@ import (
const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
- stickySessionTTL = time.Hour // 粘性会话TTL
- ClientAffinityTTL = 24 * time.Hour // 客户端亲和TTL
+ stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 500 * 1024 * 1024
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
// to match real Claude CLI traffic as closely as possible. When we need a visual
@@ -60,28 +60,14 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
-// MediaType 媒体类型常量
-const (
- MediaTypeImage = "image"
- MediaTypeVideo = "video"
- MediaTypePrompt = "prompt"
-)
-
-const (
- claudeMaxMessageOverheadTokens = 3
- claudeMaxBlockOverheadTokens = 1
- claudeMaxUnknownContentTokens = 4
-)
-
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
// accountWithLoad 账号与负载信息的组合,用于负载感知调度
type accountWithLoad struct {
- account *Account
- loadInfo *AccountLoadInfo
- affinityCount int64 // 亲和客户端数量(反向索引),越少越优先
+ account *Account
+ loadInfo *AccountLoadInfo
}
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
@@ -345,10 +331,6 @@ var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
- // clientIDFromMetadataRegex 从 metadata.user_id 中提取客户端 ID(64位 hex)
- // 格式: user_{64位hex}_account_...
- clientIDFromMetadataRegex = regexp.MustCompile(`^user_([a-f0-9]{64})_account_`)
-
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
// 注意:前缀之间不应存在包含关系,否则会导致冗余匹配
@@ -366,12 +348,6 @@ var ErrNoAvailableAccounts = errors.New("no available accounts")
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
-// ErrAffinityNoSwitch 表示亲和账号不可用且不允许切换到其他账号
-var ErrAffinityNoSwitch = errors.New("affinity account unavailable and switching is disabled")
-
-// ErrAffinityLimitExceeded 表示亲和客户端限制已达上限
-var ErrAffinityLimitExceeded = errors.New("affinity client limit exceeded")
-
// allowedHeaders 白名单headers(参考CRS项目)
var allowedHeaders = map[string]bool{
"accept": true,
@@ -393,6 +369,8 @@ var allowedHeaders = map[string]bool{
"user-agent": true,
"content-type": true,
"accept-encoding": true,
+ "x-claude-code-session-id": true,
+ "x-client-request-id": true,
}
// GatewayCache 定义网关服务的缓存操作接口。
@@ -413,39 +391,6 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
-
- // GetAffinityAccounts 获取亲和账号列表(按最近使用降序),同时清理过期成员
- GetAffinityAccounts(ctx context.Context, groupID int64, userID int64, clientID string, ttl time.Duration) ([]int64, error)
- // UpdateAffinity 添加/更新亲和关系(更新 score 为当前时间戳,刷新 key TTL)
- UpdateAffinity(ctx context.Context, groupID int64, userID int64, clientID string, accountID int64, ttl time.Duration) error
- // GetAccountAffinityCountBatch 批量获取账号的亲和成员数量(惰性清理过期成员)
- GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error)
- // GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和成员列表(去重)
- // accountGroups: map[accountID][]groupID
- // 返回值成员格式为 {userID}/{clientID}
- GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error)
- // GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间)
- GetAccountAffinityClientsWithScores(ctx context.Context, accountID int64, groupIDs []int64, ttl time.Duration) ([]AffinityClient, error)
- // ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引)
- // 用于账号关闭亲和时立即清理旧绑定
- ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error
- // GetAffinityMultiCount 获取账号的多维度亲和计数
- // 返回: uniqueUsers, uniqueClients, perUserClients
- GetAffinityMultiCount(ctx context.Context, groupID int64, accountID int64, targetUserID int64, ttl time.Duration) (users, clients, perUser int64, err error)
-}
-
-// AffinityClient 亲和客户端信息(含用户 ID 和最后活跃时间)
-type AffinityClient struct {
- UserID int64 `json:"user_id"`
- ClientID string `json:"client_id"`
- LastActive time.Time `json:"last_active"`
-}
-
-// SortAffinityClients 按最后活跃时间降序排序
-func SortAffinityClients(clients []AffinityClient) {
- sort.Slice(clients, func(i, j int) bool {
- return clients[i].LastActive.After(clients[j].LastActive)
- })
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
@@ -516,20 +461,6 @@ func shouldClearStickySession(account *Account, requestedModel string) bool {
return false
}
-// extractClientIDFromMetadata 从 metadata.user_id 中提取客户端 ID(64位 hex)。
-// 格式: user_{64位hex}_account_..._session_...
-// 返回空字符串表示无法提取(非 Claude Code/Console 客户端)。
-func extractClientIDFromMetadata(metadataUserID string) string {
- if metadataUserID == "" {
- return ""
- }
- matches := clientIDFromMetadataRegex.FindStringSubmatch(metadataUserID)
- if matches == nil {
- return ""
- }
- return matches[1]
-}
-
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
@@ -572,10 +503,6 @@ type ForwardResult struct {
// 图片生成计费字段(图片生成模型使用)
ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K"
-
- // Sora 媒体字段
- MediaType string // image / video / prompt
- MediaURL string // 生成后的媒体地址(可选)
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
@@ -1315,10 +1242,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
- // 提取客户端 ID(用于客户端亲和调度)
- affinityClientID := extractClientIDFromMetadata(metadataUserID)
- affinityUserID := sub2apiUserID
-
if s.debugModelRoutingEnabled() && requestedModel != "" {
groupPlatform := ""
if group != nil {
@@ -1340,10 +1263,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if err != nil {
return nil, err
}
- if shouldFilterAccountWithoutClientID(account, affinityClientID) {
- localExcluded[account.ID] = struct{}{}
- continue
- }
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
@@ -1405,7 +1324,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if err != nil {
return nil, err
}
- accounts = filterAccountsWithoutClientID(accounts, affinityClientID)
if len(accounts) == 0 {
return nil, ErrNoAvailableAccounts
}
@@ -1424,19 +1342,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
_, excluded := excludedIDs[accountID]
return excluded
}
- affinityFlow := newGatewayAffinityFlow(
- s,
- ctx,
- groupID,
- sessionHash,
- requestedModel,
- affinityClientID,
- affinityUserID,
- platform,
- useMixed,
- accountByID,
- isExcluded,
- )
// 获取模型路由配置(仅 anthropic 平台)
var routingAccountIDs []int64
@@ -1599,10 +1504,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
if len(routingAvailable) > 0 {
- // 批量获取亲和客户端数量
- s.populateAffinityCounts(ctx, routingAvailable, derefGroupID(groupID))
-
- // 排序:优先级 > 负载率 > 亲和客户端数 > 最后使用时间
+ // 排序:优先级 > 负载率 > 最后使用时间
sort.SliceStable(routingAvailable, func(i, j int) bool {
a, b := routingAvailable[i], routingAvailable[j]
if a.account.Priority != b.account.Priority {
@@ -1611,9 +1513,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
- if a.affinityCount != b.affinityCount {
- return a.affinityCount < b.affinityCount
- }
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
@@ -1639,9 +1538,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
}
- if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && item.account.IsAffinityEnabled() {
- _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, item.account.ID, ClientAffinityTTL)
- }
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
@@ -1679,22 +1575,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
- // ============ Layer 1.3: 用户亲和预处理(pinned_users 自动注入) ============
- affinityFlow.preprocessPinnedUsers(accounts)
-
- // ============ Layer 1.4: 客户端亲和调度(优先于粘性会话) ============
- affinityHit := false
- if affinityResult, hit, err := affinityFlow.trySelectAffinityAccount(); err != nil {
- return nil, err
- } else {
- affinityHit = hit
- if affinityResult != nil {
- return affinityResult, nil
- }
- }
-
- // ============ Layer 1.5: 粘性会话(仅在无模型路由配置 且 亲和未命中时生效) ============
- if !affinityHit && len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) {
+ // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
+ if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) {
accountID := stickyAccountID
if accountID > 0 && !isExcluded(accountID) {
account, ok := accountByID[accountID]
@@ -1800,9 +1682,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
- if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && result.Account != nil && result.Account.IsAffinityEnabled() {
- _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, result.Account.ID, ClientAffinityTTL)
- }
return result, nil
}
} else {
@@ -1820,37 +1699,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
- // 批量获取亲和客户端数量(用于均衡分配新客户端)
- s.populateAffinityCounts(ctx, available, derefGroupID(groupID))
-
- // 分层过滤选择:优先级 → 亲和三区 → 负载率 → 亲和客户端数 → LRU
+ // 分层过滤选择:优先级 → 负载率 → LRU
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
- // 2. 按亲和三区过滤:绿区优先 → 黄区降级 → 红区移除(在同优先级内)
- candidates = classifyByAffinityZone(candidates)
- if len(candidates) == 0 {
- // 当前优先级组全部在红区,移除后回退到下一优先级组
- minPri := available[0].account.Priority
- for _, a := range available[1:] {
- if a.account.Priority < minPri {
- minPri = a.account.Priority
- }
- }
- newAvailable := make([]accountWithLoad, 0, len(available))
- for _, a := range available {
- if a.account.Priority != minPri {
- newAvailable = append(newAvailable, a)
- }
- }
- available = newAvailable
- continue
- }
- // 3. 取负载率最低的集合
+ // 2. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
- // 3. 取亲和客户端数最少的集合
- candidates = filterByMinAffinityCount(candidates)
- // 4. LRU 选择最久未用的账号
+ // 3. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
@@ -1865,10 +1720,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
- // 更新亲和关系
- if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && selected.account.IsAffinityEnabled() {
- _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, selected.account.ID, ClientAffinityTTL)
- }
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
@@ -2077,9 +1928,6 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
- if platform == PlatformSora {
- return s.listSoraSchedulableAccounts(ctx, groupID)
- }
if s.schedulerSnapshot != nil {
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err == nil {
@@ -2176,53 +2024,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
return accounts, useMixed, nil
}
-func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) {
- const useMixed = false
-
- var accounts []Account
- var err error
- if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
- accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
- } else if groupID != nil {
- accounts, err = s.accountRepo.ListByGroup(ctx, *groupID)
- } else {
- accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
- }
- if err != nil {
- slog.Debug("account_scheduling_list_failed",
- "group_id", derefGroupID(groupID),
- "platform", PlatformSora,
- "error", err)
- return nil, useMixed, err
- }
-
- filtered := make([]Account, 0, len(accounts))
- for _, acc := range accounts {
- if acc.Platform != PlatformSora {
- continue
- }
- if !s.isSoraAccountSchedulable(&acc) {
- continue
- }
- filtered = append(filtered, acc)
- }
- slog.Debug("account_scheduling_list_sora",
- "group_id", derefGroupID(groupID),
- "platform", PlatformSora,
- "raw_count", len(accounts),
- "filtered_count", len(filtered))
- for _, acc := range filtered {
- slog.Debug("account_scheduling_account_detail",
- "account_id", acc.ID,
- "name", acc.Name,
- "platform", acc.Platform,
- "type", acc.Type,
- "status", acc.Status,
- "tls_fingerprint", acc.IsTLSFingerprintEnabled())
- }
- return filtered, useMixed, nil
-}
-
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context,
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
@@ -2247,33 +2048,10 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
return account.Platform == platform
}
-func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool {
- return s.soraUnschedulableReason(account) == ""
-}
-
-func (s *GatewayService) soraUnschedulableReason(account *Account) string {
- if account == nil {
- return "account_nil"
- }
- if account.Status != StatusActive {
- return fmt.Sprintf("status=%s", account.Status)
- }
- if !account.Schedulable {
- return "schedulable=false"
- }
- if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
- return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339))
- }
- return ""
-}
-
func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool {
if account == nil {
return false
}
- if account.Platform == PlatformSora {
- return s.isSoraAccountSchedulable(account)
- }
return account.IsSchedulable()
}
@@ -2281,12 +2059,6 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
if account == nil {
return false
}
- if account.Platform == PlatformSora {
- if !s.isSoraAccountSchedulable(account) {
- return false
- }
- return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0
- }
return account.IsSchedulableForModelWithContext(ctx, requestedModel)
}
@@ -2626,36 +2398,6 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
return s.accountRepo.GetByID(ctx, accountID)
}
-// populateAffinityCounts 批量获取账号的亲和客户端数量并填入 accountWithLoad 切片。
-// 仅当存在开启了客户端亲和的账号时才查询 Redis,否则跳过。
-func (s *GatewayService) populateAffinityCounts(ctx context.Context, accounts []accountWithLoad, groupID int64) {
- if s.cache == nil || len(accounts) == 0 {
- return
- }
- // 快速检查:是否有任何账号开启了亲和
- hasAffinity := false
- for _, acc := range accounts {
- if acc.account.IsAffinityEnabled() {
- hasAffinity = true
- break
- }
- }
- if !hasAffinity {
- return
- }
- accountIDs := make([]int64, len(accounts))
- for i, acc := range accounts {
- accountIDs[i] = acc.account.ID
- }
- countMap, err := s.cache.GetAccountAffinityCountBatch(ctx, groupID, accountIDs, ClientAffinityTTL)
- if err != nil {
- return // 查询失败不影响调度,affinityCount 保持 0
- }
- for i := range accounts {
- accounts[i].affinityCount = countMap[accounts[i].account.ID]
- }
-}
-
// filterByMinPriority 过滤出优先级最小的账号集合
func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 {
@@ -2696,64 +2438,6 @@ func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad {
return result
}
-// filterByMinAffinityCount 过滤出亲和客户端数最少的账号集合
-func filterByMinAffinityCount(accounts []accountWithLoad) []accountWithLoad {
- if len(accounts) == 0 {
- return accounts
- }
- minCount := accounts[0].affinityCount
- for _, acc := range accounts[1:] {
- if acc.affinityCount < minCount {
- minCount = acc.affinityCount
- }
- }
- result := make([]accountWithLoad, 0, len(accounts))
- for _, acc := range accounts {
- if acc.affinityCount == minCount {
- result = append(result, acc)
- }
- }
- return result
-}
-
-// classifyByAffinityZone 按亲和分区对候选账号进行分类。
-// 返回值:仅绿区账号(有绿区时),否则返回黄区账号。红区账号被移除。
-// 如果没有任何账号开启了亲和三区配置(即 affinity_base <= 0),则原样返回所有账号。
-func classifyByAffinityZone(accounts []accountWithLoad) []accountWithLoad {
- if len(accounts) == 0 {
- return accounts
- }
- // 快速检查:是否有任何账号配置了 affinity_base
- hasZoneConfig := false
- for _, acc := range accounts {
- if acc.account.IsAffinityEnabled() && acc.account.GetAffinityBase() > 0 {
- hasZoneConfig = true
- break
- }
- }
- if !hasZoneConfig {
- return accounts
- }
-
- greens := make([]accountWithLoad, 0, len(accounts))
- yellows := make([]accountWithLoad, 0, len(accounts))
- for _, acc := range accounts {
- zone := acc.account.GetAffinityZone(acc.affinityCount)
- switch zone {
- case AffinityZoneGreen:
- greens = append(greens, acc)
- case AffinityZoneYellow:
- yellows = append(yellows, acc)
- case AffinityZoneRed:
- // 红区:移除,不参与调度
- }
- }
- if len(greens) > 0 {
- return greens
- }
- return yellows
-}
-
// selectByLRU 从集合中选择最久未用的账号
// 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个
func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad {
@@ -3514,9 +3198,6 @@ func (s *GatewayService) logDetailedSelectionFailure(
stats.SampleMappingIDs,
stats.SampleRateLimitIDs,
)
- if platform == PlatformSora {
- s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling)
- }
return stats
}
@@ -3574,9 +3255,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
}
if !s.isAccountSchedulableForSelection(acc) {
detail := "generic_unschedulable"
- if acc.Platform == PlatformSora {
- detail = s.soraUnschedulableReason(acc)
- }
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
}
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
@@ -3601,57 +3279,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "eligible"}
}
-func (s *GatewayService) logSoraSelectionFailureDetails(
- ctx context.Context,
- groupID *int64,
- sessionHash string,
- requestedModel string,
- accounts []Account,
- excludedIDs map[int64]struct{},
- allowMixedScheduling bool,
-) {
- const maxLines = 30
- logged := 0
-
- for i := range accounts {
- if logged >= maxLines {
- break
- }
- acc := &accounts[i]
- diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling)
- if diagnosis.Category == "eligible" {
- continue
- }
- detail := diagnosis.Detail
- if detail == "" {
- detail = "-"
- }
- logger.LegacyPrintf(
- "service.gateway",
- "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s",
- derefGroupID(groupID),
- requestedModel,
- shortSessionHash(sessionHash),
- acc.ID,
- acc.Platform,
- diagnosis.Category,
- detail,
- )
- logged++
- }
- if len(accounts) > maxLines {
- logger.LegacyPrintf(
- "service.gateway",
- "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d",
- derefGroupID(groupID),
- requestedModel,
- shortSessionHash(sessionHash),
- len(accounts),
- logged,
- )
- }
-}
-
+// isPlatformFilteredForSelection 判断账号是否因平台不匹配而被调度选择过滤
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
if acc == nil {
return true
@@ -3730,9 +3358,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
return mapAntigravityModel(account, requestedModel) != ""
}
- if account.Platform == PlatformSora {
- return s.isSoraModelSupportedByAccount(account, requestedModel)
- }
if account.IsBedrock() {
_, ok := ResolveBedrockModelID(account, requestedModel)
return ok
@@ -3749,143 +3374,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
return account.IsModelSupported(requestedModel)
}
-func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool {
- if account == nil {
- return false
- }
- if strings.TrimSpace(requestedModel) == "" {
- return true
- }
-
- // 先走原始精确/通配符匹配。
- mapping := account.GetModelMapping()
- if len(mapping) == 0 || account.IsModelSupported(requestedModel) {
- return true
- }
-
- aliases := buildSoraModelAliases(requestedModel)
- if len(aliases) == 0 {
- return false
- }
-
- hasSoraSelector := false
- for pattern := range mapping {
- if !isSoraModelSelector(pattern) {
- continue
- }
- hasSoraSelector = true
- if matchPatternAnyAlias(pattern, aliases) {
- return true
- }
- }
-
- // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*),
- // 此时不应误拦截 Sora 模型请求。
- if !hasSoraSelector {
- return true
- }
-
- return false
-}
-
-func matchPatternAnyAlias(pattern string, aliases []string) bool {
- normalizedPattern := strings.ToLower(strings.TrimSpace(pattern))
- if normalizedPattern == "" {
- return false
- }
- for _, alias := range aliases {
- if matchWildcard(normalizedPattern, alias) {
- return true
- }
- }
- return false
-}
-
-func isSoraModelSelector(pattern string) bool {
- p := strings.ToLower(strings.TrimSpace(pattern))
- if p == "" {
- return false
- }
-
- switch {
- case strings.HasPrefix(p, "sora"),
- strings.HasPrefix(p, "gpt-image"),
- strings.HasPrefix(p, "prompt-enhance"),
- strings.HasPrefix(p, "sy_"):
- return true
- }
-
- return p == "video" || p == "image"
-}
-
-func buildSoraModelAliases(requestedModel string) []string {
- modelID := strings.ToLower(strings.TrimSpace(requestedModel))
- if modelID == "" {
- return nil
- }
-
- aliases := make([]string, 0, 8)
- addAlias := func(value string) {
- v := strings.ToLower(strings.TrimSpace(value))
- if v == "" {
- return
- }
- for _, existing := range aliases {
- if existing == v {
- return
- }
- }
- aliases = append(aliases, v)
- }
-
- addAlias(modelID)
- cfg, ok := GetSoraModelConfig(modelID)
- if ok {
- addAlias(cfg.Model)
- switch cfg.Type {
- case "video":
- addAlias("video")
- addAlias("sora")
- addAlias(soraVideoFamilyAlias(modelID))
- case "image":
- addAlias("image")
- addAlias("gpt-image")
- case "prompt_enhance":
- addAlias("prompt-enhance")
- }
- return aliases
- }
-
- switch {
- case strings.HasPrefix(modelID, "sora"):
- addAlias("video")
- addAlias("sora")
- addAlias(soraVideoFamilyAlias(modelID))
- case strings.HasPrefix(modelID, "gpt-image"):
- addAlias("image")
- addAlias("gpt-image")
- case strings.HasPrefix(modelID, "prompt-enhance"):
- addAlias("prompt-enhance")
- default:
- return nil
- }
-
- return aliases
-}
-
-func soraVideoFamilyAlias(modelID string) string {
- switch {
- case strings.HasPrefix(modelID, "sora2pro-hd"):
- return "sora2pro-hd"
- case strings.HasPrefix(modelID, "sora2pro"):
- return "sora2pro"
- case strings.HasPrefix(modelID, "sora2"):
- return "sora2"
- default:
- return ""
- }
-}
-
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
@@ -4412,10 +3900,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, err
}
- // 获取代理URL
+ // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
- proxyURL = account.Proxy.URL()
+ if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
+ proxyURL = account.Proxy.URL()
+ }
}
// 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析)
@@ -4824,7 +4314,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 处理正常响应
- ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed)
// 触发上游接受回调(提前释放串行锁,不等流完成)
if parsed.OnUpstreamAccepted != nil {
@@ -5891,6 +5380,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
targetURL = validatedURL + "/v1/messages?beta=true"
}
+ } else if account.IsCustomBaseURLEnabled() {
+ customURL := account.GetCustomBaseURL()
+ if customURL == "" {
+ return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
+ }
+ validatedURL, err := s.validateUpstreamBaseURL(customURL)
+ if err != nil {
+ return nil, err
+ }
+ targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account)
}
clientHeaders := http.Header{}
@@ -6006,6 +5505,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
+ // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
+ if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
+ if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
+ if parsed := ParseMetadataUserID(uid); parsed != nil {
+ setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
+ }
+ }
+ }
+
// === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{
"url": req.URL.String(),
@@ -7005,7 +6513,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
sawTerminalEvent := false
- skipAccountTTLOverride := false
pendingEventLines := make([]string, 0, 4)
@@ -7067,25 +6574,17 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
eventChanged = reconcileCachedTokens(u) || eventChanged
- claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
- if claudeMaxOutcome.Simulated {
- skipAccountTTLOverride = true
- }
}
}
}
if eventType == "message_delta" {
if u, ok := event["usage"].(map[string]any); ok {
eventChanged = reconcileCachedTokens(u) || eventChanged
- claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
- if claudeMaxOutcome.Simulated {
- skipAccountTTLOverride = true
- }
}
}
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
- if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride {
+ if account.IsCacheTTLOverrideEnabled() {
overrideTarget := account.GetCacheTTLOverrideTarget()
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
@@ -7517,13 +7016,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
}
- claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID)
- if claudeMaxOutcome.Simulated {
- body = rewriteClaudeUsageJSONBytes(body, response.Usage)
- }
-
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
- if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated {
+ if account.IsCacheTTLOverrideEnabled() {
overrideTarget := account.GetCacheTTLOverrideTarget()
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
// 同步更新 body JSON 中的嵌套 cache_creation 对象
@@ -7901,12 +7395,10 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
type recordUsageOpts struct {
- // Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入)
+ // ParsedRequest(可选,仅 Claude 路径传入)
ParsedRequest *ParsedRequest
// EnableClaudePath 启用 Claude 路径特有逻辑:
- // - Claude Max 缓存计费策略
- // - Sora 媒体类型分支(image/video/prompt)
// - MediaType 字段写入使用日志
EnableClaudePath bool
@@ -7998,8 +7490,6 @@ type recordUsageCoreInput struct {
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为:
-// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
-// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt)
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
result := input.Result
@@ -8017,21 +7507,9 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
result.Usage.InputTokens = 0
}
- // Claude Max cache billing policy(仅 Claude 路径启用)
- cacheTTLOverridden := false
- simulatedClaudeMax := false
- if opts.EnableClaudePath {
- var apiKeyGroup *Group
- if apiKey != nil {
- apiKeyGroup = apiKey.Group
- }
- claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, opts.ParsedRequest, apiKeyGroup, result.Model, account.ID)
- simulatedClaudeMax = claudeMaxOutcome.Simulated ||
- (shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, opts.ParsedRequest) && hasCacheCreationTokens(result.Usage))
- }
-
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
- if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
+ cacheTTLOverridden := false
+ if account.IsCacheTTLOverrideEnabled() {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}
@@ -8113,16 +7591,6 @@ func (s *GatewayService) calculateRecordUsageCost(
multiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
- // Sora 媒体类型分支(仅 Claude 路径启用)
- if opts.EnableClaudePath {
- if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
- return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
- }
- if result.MediaType == MediaTypePrompt {
- return &CostBreakdown{}
- }
- }
-
// 图片生成计费
if result.ImageCount > 0 {
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
@@ -8132,28 +7600,6 @@ func (s *GatewayService) calculateRecordUsageCost(
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
}
-// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
-func (s *GatewayService) calculateSoraMediaCost(
- result *ForwardResult,
- apiKey *APIKey,
- billingModel string,
- multiplier float64,
-) *CostBreakdown {
- var soraConfig *SoraPriceConfig
- if apiKey.Group != nil {
- soraConfig = &SoraPriceConfig{
- ImagePrice360: apiKey.Group.SoraImagePrice360,
- ImagePrice540: apiKey.Group.SoraImagePrice540,
- VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
- VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
- }
- }
- if result.MediaType == MediaTypeImage {
- return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
- }
- return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
-}
-
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
@@ -8176,7 +7622,7 @@ func (s *GatewayService) calculateImageCost(
billingModel string,
multiplier float64,
) *CostBreakdown {
- if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
+ if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
@@ -8191,6 +7637,7 @@ func (s *GatewayService) calculateImageCost(
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
+ Resolved: resolved,
})
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err)
@@ -8233,7 +7680,7 @@ func (s *GatewayService) calculateTokenCost(
var err error
// 优先尝试渠道定价 → CalculateCostUnified
- if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
+ if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
@@ -8243,6 +7690,7 @@ func (s *GatewayService) calculateTokenCost(
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
+ Resolved: resolved,
})
} else if opts.LongContextThreshold > 0 {
// 长上下文双倍计费(如 Gemini 200K 阈值)
@@ -8330,13 +7778,7 @@ func (s *GatewayService) buildRecordUsageLog(
}
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
-// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
- isSoraMedia := opts.EnableClaudePath &&
- (result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
- if isSoraMedia {
- return nil
- }
var mode string
switch {
case cost != nil && cost.BillingMode != "":
@@ -8350,9 +7792,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
}
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
- if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
- return &result.MediaType
- }
return nil
}
@@ -8559,10 +7998,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return err
}
- // 获取代理URL
+ // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
- proxyURL = account.Proxy.URL()
+ if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
+ proxyURL = account.Proxy.URL()
+ }
}
// 发送请求
@@ -8841,6 +8282,16 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
}
+ } else if account.IsCustomBaseURLEnabled() {
+ customURL := account.GetCustomBaseURL()
+ if customURL == "" {
+ return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
+ }
+ validatedURL, err := s.validateUpstreamBaseURL(customURL)
+ if err != nil {
+ return nil, err
+ }
+ targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account)
}
clientHeaders := http.Header{}
@@ -8946,6 +8397,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
+ // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
+ if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
+ if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
+ if parsed := ParseMetadataUserID(uid); parsed != nil {
+ setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
+ }
+ }
+ }
+
if c != nil && tokenType == "oauth" {
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
}
@@ -8967,6 +8427,19 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
})
}
+// buildCustomRelayURL 构建自定义中继转发 URL
+// 在 path 后附加 beta=true 和可选的 proxy 查询参数
+func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string {
+ u := strings.TrimRight(baseURL, "/") + path + "?beta=true"
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL := account.Proxy.URL()
+ if proxyURL != "" {
+ u += "&proxy=" + url.QueryEscape(proxyURL)
+ }
+ }
+ return u
+}
+
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
From 794e81720834913c5fb3251a5c13fd88040c354a Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 5 Apr 2026 16:39:24 +0800
Subject: [PATCH 07/88] refactor: remove PaymentChannel, reuse upstream Channel
with features field
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Delete payment_channels table and PaymentChannel Ent schema
- Add `features` column to upstream channels table (migration 095)
- Add Features field to Channel struct, input types, handler request/response
- Payment user/admin handlers now use ChannelService directly
- Remove Channel CRUD from PaymentConfigService and admin payment routes
- Remove "渠道管理" tab from admin orders page (use /admin/channels)
---
backend/ent/client.go | 36 +-
backend/ent/intercept/intercept.go | 1 +
backend/ent/predicate/predicate.go | 1 +
.../internal/handler/admin/channel_handler.go | 6 +
backend/internal/handler/payment_handler.go | 1 +
backend/internal/repository/channel_repo.go | 20 +-
backend/internal/server/routes/payment.go | 13 +-
backend/internal/service/channel.go | 1 +
backend/internal/service/channel_service.go | 6 +
.../service/payment_config_service.go | 440 +++++++++++-------
backend/migrations/095_channel_features.sql | 2 +
11 files changed, 314 insertions(+), 213 deletions(-)
create mode 100644 backend/migrations/095_channel_features.sql
diff --git a/backend/ent/client.go b/backend/ent/client.go
index e52e015a..3da7acf8 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -333,10 +333,10 @@ func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
- c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Use(hooks...)
@@ -349,10 +349,10 @@ func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
- c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Intercept(interceptors...)
@@ -4629,19 +4629,19 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
type (
hooks struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
- PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
- SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
- UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
- UserAttributeValue, UserSubscription []ent.Hook
+ ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog,
+ PaymentOrder, PaymentProviderInstance, PromoCode,
+ PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
+ TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
+ UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
- PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
- SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
- UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
- UserAttributeValue, UserSubscription []ent.Interceptor
+ ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog,
+ PaymentOrder, PaymentProviderInstance, PromoCode,
+ PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
+ TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
+ UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
}
)
diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go
index 8d8320bb..77d3e16e 100644
--- a/backend/ent/intercept/intercept.go
+++ b/backend/ent/intercept/intercept.go
@@ -336,6 +336,7 @@ func (f TraversePaymentAuditLog) Traverse(ctx context.Context, q ent.Query) erro
return fmt.Errorf("unexpected query type %T. expect *ent.PaymentAuditLogQuery", q)
}
+
// The PaymentOrderFunc type is an adapter to allow the use of ordinary function as a Querier.
type PaymentOrderFunc func(context.Context, *ent.PaymentOrderQuery) (ent.Value, error)
diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go
index ef551940..67f37c75 100644
--- a/backend/ent/predicate/predicate.go
+++ b/backend/ent/predicate/predicate.go
@@ -33,6 +33,7 @@ type IdempotencyRecord func(*sql.Selector)
// PaymentAuditLog is the predicate function for paymentauditlog builders.
type PaymentAuditLog func(*sql.Selector)
+
// PaymentOrder is the predicate function for paymentorder builders.
type PaymentOrder func(*sql.Selector)
diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go
index c92b35bb..d6022283 100644
--- a/backend/internal/handler/admin/channel_handler.go
+++ b/backend/internal/handler/admin/channel_handler.go
@@ -33,6 +33,7 @@ type createChannelRequest struct {
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels bool `json:"restrict_models"`
+ Features string `json:"features"`
}
type updateChannelRequest struct {
@@ -44,6 +45,7 @@ type updateChannelRequest struct {
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels *bool `json:"restrict_models"`
+ Features *string `json:"features"`
}
type channelModelPricingRequest struct {
@@ -78,6 +80,7 @@ type channelResponse struct {
Status string `json:"status"`
BillingModelSource string `json:"billing_model_source"`
RestrictModels bool `json:"restrict_models"`
+ Features string `json:"features"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
@@ -122,6 +125,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
Description: ch.Description,
Status: ch.Status,
RestrictModels: ch.RestrictModels,
+ Features: ch.Features,
GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
@@ -300,6 +304,7 @@ func (h *ChannelHandler) Create(c *gin.Context) {
ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
+ Features: req.Features,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -332,6 +337,7 @@ func (h *ChannelHandler) Update(c *gin.Context) {
ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
+ Features: req.Features,
}
if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing)
diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go
index 49c2d8d9..710322fb 100644
--- a/backend/internal/repository/channel_repo.go
+++ b/backend/internal/repository/channel_repo.go
@@ -42,9 +42,9 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
return err
}
err = tx.QueryRowContext(ctx,
- `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6)
+ `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features) VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id, created_at, updated_at`,
- channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels,
+ channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
@@ -75,9 +75,9 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
ch := &service.Channel{}
var modelMappingJSON []byte
err := r.db.QueryRowContext(ctx,
- `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at
+ `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at
FROM channels WHERE id = $1`, id,
- ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt)
+ ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound
}
@@ -108,9 +108,9 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
return err
}
result, err := tx.ExecContext(ctx,
- `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW()
- WHERE id = $7`,
- channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID,
+ `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, updated_at = NOW()
+ WHERE id = $8`,
+ channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ID,
)
if err != nil {
if isUniqueViolation(err) {
@@ -204,7 +204,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
for rows.Next() {
var ch service.Channel
var modelMappingJSON []byte
- if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
+ if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
@@ -273,7 +273,7 @@ func channelListOrderBy(params pagination.PaginationParams) string {
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
- `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`,
+ `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at FROM channels ORDER BY id`,
)
if err != nil {
return nil, fmt.Errorf("query all channels: %w", err)
@@ -285,7 +285,7 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
for rows.Next() {
var ch service.Channel
var modelMappingJSON []byte
- if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
+ if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go
index 6bf04679..828b68f3 100644
--- a/backend/internal/server/routes/payment.go
+++ b/backend/internal/server/routes/payment.go
@@ -26,7 +26,6 @@ func RegisterPaymentRoutes(
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{
authenticated.GET("/config", paymentHandler.GetPaymentConfig)
- authenticated.GET("/checkout-info", paymentHandler.GetCheckoutInfo)
authenticated.GET("/plans", paymentHandler.GetPlans)
authenticated.GET("/channels", paymentHandler.GetChannels)
authenticated.GET("/limits", paymentHandler.GetLimits)
@@ -34,7 +33,6 @@ func RegisterPaymentRoutes(
orders := authenticated.Group("/orders")
{
orders.POST("", paymentHandler.CreateOrder)
- orders.POST("/verify", paymentHandler.VerifyOrder)
orders.GET("/my", paymentHandler.GetMyOrders)
orders.GET("/:id", paymentHandler.GetOrder)
orders.POST("/:id/cancel", paymentHandler.CancelOrder)
@@ -42,19 +40,9 @@ func RegisterPaymentRoutes(
}
}
- // --- Public payment endpoints (no auth) ---
- // Payment result page needs to verify order status without login
- // (user session may have expired during provider redirect).
- public := v1.Group("/payment/public")
- {
- public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
- }
-
// --- Webhook endpoints (no auth) ---
webhook := v1.Group("/payment/webhook")
{
- // EasyPay sends GET callbacks with query params
- webhook.GET("/easypay", webhookHandler.EasyPayNotify)
webhook.POST("/easypay", webhookHandler.EasyPayNotify)
webhook.POST("/alipay", webhookHandler.AlipayNotify)
webhook.POST("/wxpay", webhookHandler.WxpayNotify)
@@ -82,6 +70,7 @@ func RegisterPaymentRoutes(
adminOrders.POST("/:id/refund", adminPaymentHandler.ProcessRefund)
}
+
// Subscription Plans
plans := adminGroup.Group("/plans")
{
diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go
index 1697ed6f..eac81444 100644
--- a/backend/internal/service/channel.go
+++ b/backend/internal/service/channel.go
@@ -39,6 +39,7 @@ type Channel struct {
Status string
BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
+ Features string // 渠道特性描述(JSON 数组),用于支付页面展示
CreatedAt time.Time
UpdatedAt time.Time
diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go
index ec8310f6..cdf94a4c 100644
--- a/backend/internal/service/channel_service.go
+++ b/backend/internal/service/channel_service.go
@@ -584,6 +584,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
+ Features: input.Features,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceChannelMapped
@@ -641,6 +642,9 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
if input.RestrictModels != nil {
channel.RestrictModels = *input.RestrictModels
}
+ if input.Features != nil {
+ channel.Features = *input.Features
+ }
// 检查分组冲突
if input.GroupIDs != nil {
@@ -842,6 +846,7 @@ type CreateChannelInput struct {
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels bool
+ Features string
}
// UpdateChannelInput 更新渠道输入
@@ -854,4 +859,5 @@ type UpdateChannelInput struct {
ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string
RestrictModels *bool
+ Features *string
}
diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go
index 9042c3ab..dafe9afd 100644
--- a/backend/internal/service/payment_config_service.go
+++ b/backend/internal/service/payment_config_service.go
@@ -2,13 +2,16 @@ package service
import (
"context"
+ "encoding/json"
"fmt"
"strconv"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
"github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const (
@@ -23,8 +26,6 @@ const (
SettingBalancePayDisabled = "BALANCE_PAYMENT_DISABLED"
SettingProductNamePrefix = "PRODUCT_NAME_PREFIX"
SettingProductNameSuffix = "PRODUCT_NAME_SUFFIX"
- SettingHelpImageURL = "PAYMENT_HELP_IMAGE_URL"
- SettingHelpText = "PAYMENT_HELP_TEXT"
SettingCancelRateLimitOn = "CANCEL_RATE_LIMIT_ENABLED"
SettingCancelRateLimitMax = "CANCEL_RATE_LIMIT_MAX"
SettingCancelWindowSize = "CANCEL_RATE_LIMIT_WINDOW"
@@ -32,126 +33,91 @@ const (
SettingCancelWindowMode = "CANCEL_RATE_LIMIT_WINDOW_MODE"
)
-// Default values for payment configuration settings.
-const (
- defaultOrderTimeoutMin = 30
- defaultMaxPendingOrders = 3
-)
-
// PaymentConfig holds the payment system configuration.
type PaymentConfig struct {
- Enabled bool `json:"enabled"`
- MinAmount float64 `json:"min_amount"`
- MaxAmount float64 `json:"max_amount"`
- DailyLimit float64 `json:"daily_limit"`
- OrderTimeoutMin int `json:"order_timeout_minutes"`
- MaxPendingOrders int `json:"max_pending_orders"`
- EnabledTypes []string `json:"enabled_payment_types"`
- BalanceDisabled bool `json:"balance_disabled"`
- LoadBalanceStrategy string `json:"load_balance_strategy"`
- ProductNamePrefix string `json:"product_name_prefix"`
- ProductNameSuffix string `json:"product_name_suffix"`
- HelpImageURL string `json:"help_image_url"`
- HelpText string `json:"help_text"`
- StripePublishableKey string `json:"stripe_publishable_key,omitempty"`
-
- // Cancel rate limit settings
- CancelRateLimitEnabled bool `json:"cancel_rate_limit_enabled"`
- CancelRateLimitMax int `json:"cancel_rate_limit_max"`
- CancelRateLimitWindow int `json:"cancel_rate_limit_window"`
- CancelRateLimitUnit string `json:"cancel_rate_limit_unit"`
- CancelRateLimitMode string `json:"cancel_rate_limit_window_mode"`
+ Enabled bool `json:"enabled"`
+ MinAmount float64 `json:"minAmount"`
+ MaxAmount float64 `json:"maxAmount"`
+ DailyLimit float64 `json:"dailyLimit"`
+ OrderTimeoutMin int `json:"orderTimeoutMinutes"`
+ MaxPendingOrders int `json:"maxPendingOrders"`
+ EnabledTypes []string `json:"enabledTypes"`
+ BalanceDisabled bool `json:"balanceDisabled"`
+ LoadBalanceStrategy string `json:"loadBalanceStrategy"`
+ ProductNamePrefix string `json:"productNamePrefix"`
+ ProductNameSuffix string `json:"productNameSuffix"`
}
// UpdatePaymentConfigRequest contains fields to update payment configuration.
type UpdatePaymentConfigRequest struct {
Enabled *bool `json:"enabled"`
- MinAmount *float64 `json:"min_amount"`
- MaxAmount *float64 `json:"max_amount"`
- DailyLimit *float64 `json:"daily_limit"`
- OrderTimeoutMin *int `json:"order_timeout_minutes"`
- MaxPendingOrders *int `json:"max_pending_orders"`
- EnabledTypes []string `json:"enabled_payment_types"`
- BalanceDisabled *bool `json:"balance_disabled"`
- LoadBalanceStrategy *string `json:"load_balance_strategy"`
- ProductNamePrefix *string `json:"product_name_prefix"`
- ProductNameSuffix *string `json:"product_name_suffix"`
- HelpImageURL *string `json:"help_image_url"`
- HelpText *string `json:"help_text"`
-
- // Cancel rate limit settings
- CancelRateLimitEnabled *bool `json:"cancel_rate_limit_enabled"`
- CancelRateLimitMax *int `json:"cancel_rate_limit_max"`
- CancelRateLimitWindow *int `json:"cancel_rate_limit_window"`
- CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"`
- CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"`
+ MinAmount *float64 `json:"minAmount"`
+ MaxAmount *float64 `json:"maxAmount"`
+ DailyLimit *float64 `json:"dailyLimit"`
+ OrderTimeoutMin *int `json:"orderTimeoutMinutes"`
+ MaxPendingOrders *int `json:"maxPendingOrders"`
+ EnabledTypes []string `json:"enabledTypes"`
+ BalanceDisabled *bool `json:"balanceDisabled"`
+ LoadBalanceStrategy *string `json:"loadBalanceStrategy"`
+ ProductNamePrefix *string `json:"productNamePrefix"`
+ ProductNameSuffix *string `json:"productNameSuffix"`
}
// MethodLimits holds per-payment-type limits.
type MethodLimits struct {
- PaymentType string `json:"payment_type"`
- FeeRate float64 `json:"fee_rate"`
- DailyLimit float64 `json:"daily_limit"`
- SingleMin float64 `json:"single_min"`
- SingleMax float64 `json:"single_max"`
-}
-
-// MethodLimitsResponse is the full response for the user-facing /limits API.
-// It includes per-method limits and the global widest range (union of all methods).
-type MethodLimitsResponse struct {
- Methods map[string]MethodLimits `json:"methods"`
- GlobalMin float64 `json:"global_min"` // 0 = no minimum
- GlobalMax float64 `json:"global_max"` // 0 = no maximum
+ PaymentType string `json:"paymentType"`
+ FeeRate float64 `json:"feeRate"`
+ DailyLimit float64 `json:"dailyLimit"`
+ SingleMin float64 `json:"singleMin"`
+ SingleMax float64 `json:"singleMax"`
}
type CreateProviderInstanceRequest struct {
- ProviderKey string `json:"provider_key"`
+ ProviderKey string `json:"providerKey"`
Name string `json:"name"`
Config map[string]string `json:"config"`
- SupportedTypes []string `json:"supported_types"`
+ SupportedTypes string `json:"supportedTypes"`
Enabled bool `json:"enabled"`
- PaymentMode string `json:"payment_mode"`
- SortOrder int `json:"sort_order"`
+ SortOrder int `json:"sortOrder"`
Limits string `json:"limits"`
- RefundEnabled bool `json:"refund_enabled"`
+ RefundEnabled bool `json:"refundEnabled"`
}
type UpdateProviderInstanceRequest struct {
Name *string `json:"name"`
Config map[string]string `json:"config"`
- SupportedTypes []string `json:"supported_types"`
+ SupportedTypes *string `json:"supportedTypes"`
Enabled *bool `json:"enabled"`
- PaymentMode *string `json:"payment_mode"`
- SortOrder *int `json:"sort_order"`
+ SortOrder *int `json:"sortOrder"`
Limits *string `json:"limits"`
- RefundEnabled *bool `json:"refund_enabled"`
+ RefundEnabled *bool `json:"refundEnabled"`
}
type CreatePlanRequest struct {
- GroupID int64 `json:"group_id"`
+ GroupID int64 `json:"groupId"`
Name string `json:"name"`
Description string `json:"description"`
Price float64 `json:"price"`
- OriginalPrice *float64 `json:"original_price"`
- ValidityDays int `json:"validity_days"`
- ValidityUnit string `json:"validity_unit"`
+ OriginalPrice *float64 `json:"originalPrice"`
+ ValidityDays int `json:"validityDays"`
+ ValidityUnit string `json:"validityUnit"`
Features string `json:"features"`
- ProductName string `json:"product_name"`
- ForSale bool `json:"for_sale"`
- SortOrder int `json:"sort_order"`
+ ProductName string `json:"productName"`
+ ForSale bool `json:"forSale"`
+ SortOrder int `json:"sortOrder"`
}
type UpdatePlanRequest struct {
- GroupID *int64 `json:"group_id"`
+ GroupID *int64 `json:"groupId"`
Name *string `json:"name"`
Description *string `json:"description"`
Price *float64 `json:"price"`
- OriginalPrice *float64 `json:"original_price"`
- ValidityDays *int `json:"validity_days"`
- ValidityUnit *string `json:"validity_unit"`
+ OriginalPrice *float64 `json:"originalPrice"`
+ ValidityDays *int `json:"validityDays"`
+ ValidityUnit *string `json:"validityUnit"`
Features *string `json:"features"`
- ProductName *string `json:"product_name"`
- ForSale *bool `json:"for_sale"`
- SortOrder *int `json:"sort_order"`
+ ProductName *string `json:"productName"`
+ ForSale *bool `json:"forSale"`
+ SortOrder *int `json:"sortOrder"`
}
// PaymentConfigService manages payment configuration and CRUD for
@@ -183,43 +149,29 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
SettingDailyRechargeLimit, SettingOrderTimeoutMinutes, SettingMaxPendingOrders,
SettingEnabledPaymentTypes, SettingBalancePayDisabled, SettingLoadBalanceStrategy,
SettingProductNamePrefix, SettingProductNameSuffix,
- SettingHelpImageURL, SettingHelpText,
- SettingCancelRateLimitOn, SettingCancelRateLimitMax,
- SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
}
vals, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get payment config settings: %w", err)
}
- cfg := s.parsePaymentConfig(vals)
- // Load Stripe publishable key from the first enabled Stripe provider instance
- cfg.StripePublishableKey = s.getStripePublishableKey(ctx)
- return cfg, nil
+ return s.parsePaymentConfig(vals), nil
}
func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *PaymentConfig {
cfg := &PaymentConfig{
Enabled: vals[SettingPaymentEnabled] == "true",
MinAmount: pcParseFloat(vals[SettingMinRechargeAmount], 1),
- MaxAmount: pcParseFloat(vals[SettingMaxRechargeAmount], 0),
+ MaxAmount: pcParseFloat(vals[SettingMaxRechargeAmount], 99999999.99),
DailyLimit: pcParseFloat(vals[SettingDailyRechargeLimit], 0),
- OrderTimeoutMin: pcParseInt(vals[SettingOrderTimeoutMinutes], defaultOrderTimeoutMin),
- MaxPendingOrders: pcParseInt(vals[SettingMaxPendingOrders], defaultMaxPendingOrders),
+ OrderTimeoutMin: pcParseInt(vals[SettingOrderTimeoutMinutes], 30),
+ MaxPendingOrders: pcParseInt(vals[SettingMaxPendingOrders], 3),
BalanceDisabled: vals[SettingBalancePayDisabled] == "true",
LoadBalanceStrategy: vals[SettingLoadBalanceStrategy],
ProductNamePrefix: vals[SettingProductNamePrefix],
ProductNameSuffix: vals[SettingProductNameSuffix],
- HelpImageURL: vals[SettingHelpImageURL],
- HelpText: vals[SettingHelpText],
-
- CancelRateLimitEnabled: vals[SettingCancelRateLimitOn] == "true",
- CancelRateLimitMax: pcParseInt(vals[SettingCancelRateLimitMax], 10),
- CancelRateLimitWindow: pcParseInt(vals[SettingCancelWindowSize], 1),
- CancelRateLimitUnit: vals[SettingCancelWindowUnit],
- CancelRateLimitMode: vals[SettingCancelWindowMode],
}
if cfg.LoadBalanceStrategy == "" {
- cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
+ cfg.LoadBalanceStrategy = "round-robin"
}
if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
for _, t := range strings.Split(raw, ",") {
@@ -232,100 +184,242 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme
return cfg
}
-// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
-func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
- instances, err := s.entClient.PaymentProviderInstance.Query().
- Where(
- paymentproviderinstance.EnabledEQ(true),
- paymentproviderinstance.ProviderKeyEQ(payment.TypeStripe),
- ).Limit(1).All(ctx)
- if err != nil || len(instances) == 0 {
- return ""
- }
- cfg, err := s.decryptConfig(instances[0].Config)
- if err != nil || cfg == nil {
- return ""
- }
- return cfg[payment.ConfigKeyPublishableKey]
-}
-
// UpdatePaymentConfig updates the payment configuration settings.
-// NOTE: This function exceeds 30 lines because each field requires an independent
-// nil-check before serialisation — this is inherent to patch-style update patterns
-// and cannot be meaningfully decomposed without introducing unnecessary abstraction.
func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req UpdatePaymentConfigRequest) error {
- m := map[string]string{
- SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled),
- SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount),
- SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount),
- SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit),
- SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin),
- SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders),
- SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled),
- SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy),
- SettingProductNamePrefix: derefStr(req.ProductNamePrefix),
- SettingProductNameSuffix: derefStr(req.ProductNameSuffix),
- SettingHelpImageURL: derefStr(req.HelpImageURL),
- SettingHelpText: derefStr(req.HelpText),
- SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled),
- SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax),
- SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow),
- SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit),
- SettingCancelWindowMode: derefStr(req.CancelRateLimitMode),
+ m := make(map[string]string)
+ if req.Enabled != nil {
+ m[SettingPaymentEnabled] = strconv.FormatBool(*req.Enabled)
+ }
+ if req.MinAmount != nil {
+ m[SettingMinRechargeAmount] = strconv.FormatFloat(*req.MinAmount, 'f', 2, 64)
+ }
+ if req.MaxAmount != nil {
+ m[SettingMaxRechargeAmount] = strconv.FormatFloat(*req.MaxAmount, 'f', 2, 64)
+ }
+ if req.DailyLimit != nil {
+ m[SettingDailyRechargeLimit] = strconv.FormatFloat(*req.DailyLimit, 'f', 2, 64)
+ }
+ if req.OrderTimeoutMin != nil {
+ m[SettingOrderTimeoutMinutes] = strconv.Itoa(*req.OrderTimeoutMin)
+ }
+ if req.MaxPendingOrders != nil {
+ m[SettingMaxPendingOrders] = strconv.Itoa(*req.MaxPendingOrders)
}
if req.EnabledTypes != nil {
m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",")
- } else {
- m[SettingEnabledPaymentTypes] = ""
+ }
+ if req.BalanceDisabled != nil {
+ m[SettingBalancePayDisabled] = strconv.FormatBool(*req.BalanceDisabled)
+ }
+ if req.LoadBalanceStrategy != nil {
+ m[SettingLoadBalanceStrategy] = *req.LoadBalanceStrategy
+ }
+ if req.ProductNamePrefix != nil {
+ m[SettingProductNamePrefix] = *req.ProductNamePrefix
+ }
+ if req.ProductNameSuffix != nil {
+ m[SettingProductNameSuffix] = *req.ProductNameSuffix
+ }
+ if len(m) == 0 {
+ return nil
}
return s.settingRepo.SetMultiple(ctx, m)
}
-func formatBoolOrEmpty(v *bool) string {
- if v == nil {
- return ""
- }
- return strconv.FormatBool(*v)
+// --- Provider Instance CRUD ---
+
+func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*dbent.PaymentProviderInstance, error) {
+ return s.entClient.PaymentProviderInstance.Query().Order(paymentproviderinstance.BySortOrder()).All(ctx)
}
-func formatPositiveFloat(v *float64) string {
- if v == nil || *v <= 0 {
- return "" // empty → parsePaymentConfig uses default
+func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req CreateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
+ enc, err := s.encryptConfig(req.Config)
+ if err != nil {
+ return nil, err
}
- return strconv.FormatFloat(*v, 'f', 2, 64)
+ return s.entClient.PaymentProviderInstance.Create().
+ SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc).
+ SetSupportedTypes(req.SupportedTypes).SetEnabled(req.Enabled).
+ SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled).
+ Save(ctx)
}
-func formatPositiveInt(v *int) string {
- if v == nil || *v <= 0 {
- return ""
+func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
+ u := s.entClient.PaymentProviderInstance.UpdateOneID(id)
+ if req.Name != nil {
+ u.SetName(*req.Name)
}
- return strconv.Itoa(*v)
+ if req.Config != nil {
+ enc, err := s.encryptConfig(req.Config)
+ if err != nil {
+ return nil, err
+ }
+ u.SetConfig(enc)
+ }
+ if req.SupportedTypes != nil {
+ u.SetSupportedTypes(*req.SupportedTypes)
+ }
+ if req.Enabled != nil {
+ u.SetEnabled(*req.Enabled)
+ }
+ if req.SortOrder != nil {
+ u.SetSortOrder(*req.SortOrder)
+ }
+ if req.Limits != nil {
+ u.SetLimits(*req.Limits)
+ }
+ if req.RefundEnabled != nil {
+ u.SetRefundEnabled(*req.RefundEnabled)
+ }
+ return u.Save(ctx)
}
-func derefStr(v *string) string {
- if v == nil {
- return ""
- }
- return *v
+func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error {
+ return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx)
}
-func splitTypes(s string) []string {
- if s == "" {
- return nil
+func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) {
+ data, err := json.Marshal(cfg)
+ if err != nil {
+ return "", fmt.Errorf("marshal config: %w", err)
}
- parts := strings.Split(s, ",")
- result := make([]string, 0, len(parts))
- for _, p := range parts {
- p = strings.TrimSpace(p)
- if p != "" {
- result = append(result, p)
+ enc, err := payment.Encrypt(string(data), s.encryptionKey)
+ if err != nil {
+ return "", fmt.Errorf("encrypt config: %w", err)
+ }
+ return enc, nil
+}
+
+// --- Channel CRUD ---
+
+
+// --- Plan CRUD ---
+
+func (s *PaymentConfigService) ListPlans(ctx context.Context) ([]*dbent.SubscriptionPlan, error) {
+ return s.entClient.SubscriptionPlan.Query().Order(subscriptionplan.BySortOrder()).All(ctx)
+}
+
+func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.SubscriptionPlan, error) {
+ return s.entClient.SubscriptionPlan.Query().Where(subscriptionplan.ForSaleEQ(true)).Order(subscriptionplan.BySortOrder()).All(ctx)
+}
+
+func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) {
+ b := s.entClient.SubscriptionPlan.Create().
+ SetGroupID(req.GroupID).SetName(req.Name).SetDescription(req.Description).
+ SetPrice(req.Price).SetValidityDays(req.ValidityDays).SetValidityUnit(req.ValidityUnit).
+ SetFeatures(req.Features).SetProductName(req.ProductName).
+ SetForSale(req.ForSale).SetSortOrder(req.SortOrder)
+ if req.OriginalPrice != nil {
+ b.SetOriginalPrice(*req.OriginalPrice)
+ }
+ return b.Save(ctx)
+}
+
+func (s *PaymentConfigService) UpdatePlan(ctx context.Context, id int64, req UpdatePlanRequest) (*dbent.SubscriptionPlan, error) {
+ u := s.entClient.SubscriptionPlan.UpdateOneID(id)
+ if req.GroupID != nil {
+ u.SetGroupID(*req.GroupID)
+ }
+ if req.Name != nil {
+ u.SetName(*req.Name)
+ }
+ if req.Description != nil {
+ u.SetDescription(*req.Description)
+ }
+ if req.Price != nil {
+ u.SetPrice(*req.Price)
+ }
+ if req.OriginalPrice != nil {
+ u.SetOriginalPrice(*req.OriginalPrice)
+ }
+ if req.ValidityDays != nil {
+ u.SetValidityDays(*req.ValidityDays)
+ }
+ if req.ValidityUnit != nil {
+ u.SetValidityUnit(*req.ValidityUnit)
+ }
+ if req.Features != nil {
+ u.SetFeatures(*req.Features)
+ }
+ if req.ProductName != nil {
+ u.SetProductName(*req.ProductName)
+ }
+ if req.ForSale != nil {
+ u.SetForSale(*req.ForSale)
+ }
+ if req.SortOrder != nil {
+ u.SetSortOrder(*req.SortOrder)
+ }
+ return u.Save(ctx)
+}
+
+func (s *PaymentConfigService) DeletePlan(ctx context.Context, id int64) error {
+ return s.entClient.SubscriptionPlan.DeleteOneID(id).Exec(ctx)
+}
+
+// GetPlan returns a subscription plan by ID.
+func (s *PaymentConfigService) GetPlan(ctx context.Context, id int64) (*dbent.SubscriptionPlan, error) {
+ plan, err := s.entClient.SubscriptionPlan.Get(ctx, id)
+ if err != nil {
+ return nil, infraerrors.NotFound("PLAN_NOT_FOUND", "subscription plan not found")
+ }
+ return plan, nil
+}
+
+// GetMethodLimits returns per-payment-type limits from enabled provider instances.
+func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) {
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ Where(paymentproviderinstance.EnabledEQ(true)).All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("query provider instances: %w", err)
+ }
+ result := make([]MethodLimits, 0, len(types))
+ for _, pt := range types {
+ ml := MethodLimits{PaymentType: pt}
+ for _, inst := range instances {
+ if !pcInstanceSupportsType(inst, pt) {
+ continue
+ }
+ pcApplyInstanceLimits(inst, pt, &ml)
+ }
+ result = append(result, ml)
+ }
+ return result, nil
+}
+
+func pcInstanceSupportsType(inst *dbent.PaymentProviderInstance, pt string) bool {
+ if inst.SupportedTypes == "" {
+ return true
+ }
+ for _, t := range strings.Split(inst.SupportedTypes, ",") {
+ if strings.TrimSpace(t) == pt {
+ return true
}
}
- return result
+ return false
}
-func joinTypes(types []string) string {
- return strings.Join(types, ",")
+func pcApplyInstanceLimits(inst *dbent.PaymentProviderInstance, pt string, ml *MethodLimits) {
+ if inst.Limits == "" {
+ return
+ }
+ var limits payment.InstanceLimits
+ if err := json.Unmarshal([]byte(inst.Limits), &limits); err != nil {
+ return
+ }
+ cl, ok := limits[pt]
+ if !ok {
+ return
+ }
+ if cl.DailyLimit > 0 && (ml.DailyLimit == 0 || cl.DailyLimit < ml.DailyLimit) {
+ ml.DailyLimit = cl.DailyLimit
+ }
+ if cl.SingleMin > 0 && (ml.SingleMin == 0 || cl.SingleMin > ml.SingleMin) {
+ ml.SingleMin = cl.SingleMin
+ }
+ if cl.SingleMax > 0 && (ml.SingleMax == 0 || cl.SingleMax < ml.SingleMax) {
+ ml.SingleMax = cl.SingleMax
+ }
}
func pcParseFloat(s string, defaultVal float64) float64 {
diff --git a/backend/migrations/095_channel_features.sql b/backend/migrations/095_channel_features.sql
new file mode 100644
index 00000000..5f142002
--- /dev/null
+++ b/backend/migrations/095_channel_features.sql
@@ -0,0 +1,2 @@
+ALTER TABLE channels ADD COLUMN IF NOT EXISTS features TEXT NOT NULL DEFAULT '';
+COMMENT ON COLUMN channels.features IS '渠道特性描述,JSON 数组格式,用于支付页面展示';
From 3d4d960d60dde5ffaa8213ef0a301a139d471431 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 5 Apr 2026 23:14:23 +0800
Subject: [PATCH 08/88] fix: gofmt formatting after merge
---
.../service/admin_service_clear_error_test.go | 24 +++++++++----------
.../internal/service/billing_service_test.go | 1 -
2 files changed, 12 insertions(+), 13 deletions(-)
diff --git a/backend/internal/service/admin_service_clear_error_test.go b/backend/internal/service/admin_service_clear_error_test.go
index f039612c..141466dc 100644
--- a/backend/internal/service/admin_service_clear_error_test.go
+++ b/backend/internal/service/admin_service_clear_error_test.go
@@ -12,12 +12,12 @@ import (
type accountRepoStubForClearAccountError struct {
mockAccountRepoForGemini
- account *Account
- clearErrorCalls int
- clearRateLimitCalls int
- clearAntigravityCalls int
+ account *Account
+ clearErrorCalls int
+ clearRateLimitCalls int
+ clearAntigravityCalls int
clearModelRateLimitCalls int
- clearTempUnschedCalls int
+ clearTempUnschedCalls int
}
func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) {
@@ -60,13 +60,13 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
resetAt := time.Now().Add(5 * time.Minute)
repo := &accountRepoStubForClearAccountError{
account: &Account{
- ID: 31,
- Platform: PlatformOpenAI,
- Type: AccountTypeOAuth,
- Status: StatusError,
- ErrorMessage: "refresh failed",
- RateLimitResetAt: &resetAt,
- TempUnschedulableUntil: &until,
+ ID: 31,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusError,
+ ErrorMessage: "refresh failed",
+ RateLimitResetAt: &resetAt,
+ TempUnschedulableUntil: &until,
TempUnschedulableReason: "missing refresh token",
},
}
diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go
index dd58502c..6f6c41ce 100644
--- a/backend/internal/service/billing_service_test.go
+++ b/backend/internal/service/billing_service_test.go
@@ -363,7 +363,6 @@ func TestCalculateImageCost(t *testing.T) {
require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
}
-
func TestIsModelSupported(t *testing.T) {
svc := newTestBillingService()
From 1c63ea1448b6fac211751af8562aef69f2a4ae64 Mon Sep 17 00:00:00 2001
From: erio
Date: Tue, 7 Apr 2026 13:47:12 +0800
Subject: [PATCH 09/88] fix(channel): add missing features column to List query
The paginated List query was selecting 9 columns but scanning 10 fields,
missing c.features. GetByID and ListAll already included it correctly.
Also removes the channelListOrderBy helper: the paginated List query now
always orders by c.id ASC, dropping the previous sort-by name/status/created_at
support.
---
backend/cmd/server/VERSION | 2 +-
backend/internal/repository/channel_repo.go | 31 ++-------------------
2 files changed, 4 insertions(+), 29 deletions(-)
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index 1ebb081e..8b4a4750 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.105.13
+0.1.108.73
diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go
index 710322fb..baad31f7 100644
--- a/backend/internal/repository/channel_repo.go
+++ b/backend/internal/repository/channel_repo.go
@@ -187,9 +187,9 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表
dataQuery := fmt.Sprintf(
- `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
- FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
- whereClause, channelListOrderBy(params), argIdx, argIdx+1,
+ `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.created_at, c.updated_at
+ FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`,
+ whereClause, argIdx, argIdx+1,
)
args = append(args, pageSize, offset)
@@ -246,31 +246,6 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
return channels, paginationResult, nil
}
-func channelListOrderBy(params pagination.PaginationParams) string {
- sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
- sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderAsc))
-
- var column string
- switch sortBy {
- case "":
- column = "c.id"
- sortOrder = "ASC"
- case "id":
- column = "c.id"
- case "name":
- column = "c.name"
- case "status":
- column = "c.status"
- case "created_at":
- column = "c.created_at"
- default:
- column = "c.id"
- sortOrder = "ASC"
- }
-
- return fmt.Sprintf("%s %s, c.id %s", column, sortOrder, sortOrder)
-}
-
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at FROM channels ORDER BY id`,
From 5bae3b05773f23722bae7600e9f14a1c192b64d1 Mon Sep 17 00:00:00 2001
From: erio
Date: Wed, 8 Apr 2026 17:11:32 +0800
Subject: [PATCH 10/88] fix(payment): audit fixes for alipay/wxpay/stripe
payment providers
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Backend:
- Extract YuanToFen/FenToYuan to payment/amount.go using shopspring/decimal
- Require alipay publicKey in config validation
- Fix wxpay webhook response to return JSON per V3 spec
- Remove wxpay certSerial fallback to publicKeyId
- Define magic strings as named constants in wxpay/alipay providers
- Add slog warning for wxpay H5→Native payment downgrade
- Make EncryptionKey validation return error on invalid (non-empty) key
- Make decryptConfig propagate errors instead of returning nil
- Add idempotency check in doBalance to prevent stuck FAILED retries
Frontend:
- Fix dashboard currency symbol from $ to ¥
- Fix AdminPaymentPlansView any type to proper SubscriptionPlan type
- Make quick amount buttons follow selected payment method limits
- Center help image with larger height and text below
---
backend/cmd/server/wire_gen.go | 62 +-
.../handler/payment_webhook_handler.go | 32 +-
.../service/payment_config_service.go | 440 ++++++--------
.../internal/service/payment_fulfillment.go | 112 +---
.../orders/AdminPaymentDashboardView.vue | 4 +-
frontend/src/views/user/PaymentView.vue | 557 +++++-------------
6 files changed, 388 insertions(+), 819 deletions(-)
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index c288a289..0a0cc84b 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -50,7 +50,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
settingRepository := repository.NewSettingRepository(client)
groupRepository := repository.NewGroupRepository(client, db)
- channelRepository := repository.NewChannelRepository(db)
settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache)
@@ -65,7 +64,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
- apiKeyService.SetRateLimitCacheInvalidator(billingCache)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
@@ -73,15 +71,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
- registry := payment.ProvideRegistry()
- encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
- if err != nil {
- return nil, err
- }
- defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
- paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
- paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
- paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
if err != nil {
return nil, err
@@ -92,7 +81,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
- usageBillingRepository := repository.NewUsageBillingRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemHandler := handler.NewRedeemHandler(redeemService)
@@ -110,7 +98,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
}
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
- schedulerCache := repository.ProvideSchedulerCache(redisClient, configConfig)
+ schedulerCache := repository.NewSchedulerCache(redisClient)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
@@ -120,11 +108,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
+ sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
+ rpmCache := repository.NewRPMCache(redisClient)
+ groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
+ groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
- openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient()
@@ -134,7 +125,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
- oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig)
@@ -142,23 +132,20 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
identityCache := repository.NewIdentityCache(redisClient)
- geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
- gatewayCache := repository.NewGatewayCache(redisClient)
- schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
- schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
- antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
- internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
- antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
+ oAuthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
+ geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
+ gatewayCache := repository.NewGatewayCache(redisClient)
+ schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
+ schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
+ antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
+ internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
+ antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache, accountUsageService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
- sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
- rpmCache := repository.NewRPMCache(redisClient)
- groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
- groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService()
@@ -175,6 +162,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
+ usageBillingRepository := repository.NewUsageBillingRepository(client, db)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
@@ -183,17 +171,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
- claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
+ claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
+ channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
- openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
+ openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
- settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
@@ -221,8 +209,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
channelHandler := admin.NewChannelHandler(channelService, billingService)
- adminPaymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, adminPaymentHandler)
+ registry := payment.ProvideRegistry()
+ encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
+ if err != nil {
+ return nil, err
+ }
+ defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
+ paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
+ settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService)
+ paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
+ paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
+ paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -245,7 +243,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
- tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
+ tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go
index 8a83bfeb..bf404118 100644
--- a/backend/internal/handler/payment_webhook_handler.go
+++ b/backend/internal/handler/payment_webhook_handler.go
@@ -4,7 +4,6 @@ import (
"io"
"log/slog"
"net/http"
- "net/url"
"strings"
"github.com/Wei-Shaw/sub2api/internal/payment"
@@ -73,13 +72,9 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
rawBody = string(body)
}
- // Extract out_trade_no to look up the order's specific provider instance.
- // This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts).
- outTradeNo := extractOutTradeNo(rawBody, providerKey)
-
- provider, err := h.paymentService.GetWebhookProvider(c.Request.Context(), providerKey, outTradeNo)
+ provider, err := h.registry.GetProviderByKey(providerKey)
if err != nil {
- slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err)
+ slog.Warn("[Payment Webhook] provider not registered", "provider", providerKey, "error", err)
writeSuccessResponse(c, providerKey)
return
}
@@ -116,40 +111,19 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
writeSuccessResponse(c, providerKey)
}
-// extractOutTradeNo parses the webhook body to find the out_trade_no.
-// This allows looking up the correct provider instance before verification.
-func extractOutTradeNo(rawBody, providerKey string) string {
- switch providerKey {
- case payment.TypeEasyPay:
- values, err := url.ParseQuery(rawBody)
- if err == nil {
- return values.Get("out_trade_no")
- }
- }
- // For other providers (Stripe, Alipay direct, WxPay direct), the registry
- // typically has only one instance, so no instance lookup is needed.
- return ""
-}
-
// wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook.
type wxpaySuccessResponse struct {
Code string `json:"code"`
Message string `json:"message"`
}
-// WeChat Pay webhook success response constants.
-const (
- wxpaySuccessCode = "SUCCESS"
- wxpaySuccessMessage = "成功"
-)
-
// writeSuccessResponse sends the provider-specific success response.
// WeChat Pay requires JSON {"code":"SUCCESS","message":"成功"};
// Stripe expects an empty 200; others accept plain text "success".
func writeSuccessResponse(c *gin.Context, providerKey string) {
switch providerKey {
case payment.TypeWxpay:
- c.JSON(http.StatusOK, wxpaySuccessResponse{Code: wxpaySuccessCode, Message: wxpaySuccessMessage})
+ c.JSON(http.StatusOK, wxpaySuccessResponse{Code: "SUCCESS", Message: "成功"})
case payment.TypeStripe:
c.String(http.StatusOK, "")
default:
diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go
index dafe9afd..9042c3ab 100644
--- a/backend/internal/service/payment_config_service.go
+++ b/backend/internal/service/payment_config_service.go
@@ -2,16 +2,13 @@ package service
import (
"context"
- "encoding/json"
"fmt"
"strconv"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
- "github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
"github.com/Wei-Shaw/sub2api/internal/payment"
- infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const (
@@ -26,6 +23,8 @@ const (
SettingBalancePayDisabled = "BALANCE_PAYMENT_DISABLED"
SettingProductNamePrefix = "PRODUCT_NAME_PREFIX"
SettingProductNameSuffix = "PRODUCT_NAME_SUFFIX"
+ SettingHelpImageURL = "PAYMENT_HELP_IMAGE_URL"
+ SettingHelpText = "PAYMENT_HELP_TEXT"
SettingCancelRateLimitOn = "CANCEL_RATE_LIMIT_ENABLED"
SettingCancelRateLimitMax = "CANCEL_RATE_LIMIT_MAX"
SettingCancelWindowSize = "CANCEL_RATE_LIMIT_WINDOW"
@@ -33,91 +32,126 @@ const (
SettingCancelWindowMode = "CANCEL_RATE_LIMIT_WINDOW_MODE"
)
+// Default values for payment configuration settings.
+const (
+ defaultOrderTimeoutMin = 30
+ defaultMaxPendingOrders = 3
+)
+
// PaymentConfig holds the payment system configuration.
type PaymentConfig struct {
- Enabled bool `json:"enabled"`
- MinAmount float64 `json:"minAmount"`
- MaxAmount float64 `json:"maxAmount"`
- DailyLimit float64 `json:"dailyLimit"`
- OrderTimeoutMin int `json:"orderTimeoutMinutes"`
- MaxPendingOrders int `json:"maxPendingOrders"`
- EnabledTypes []string `json:"enabledTypes"`
- BalanceDisabled bool `json:"balanceDisabled"`
- LoadBalanceStrategy string `json:"loadBalanceStrategy"`
- ProductNamePrefix string `json:"productNamePrefix"`
- ProductNameSuffix string `json:"productNameSuffix"`
+ Enabled bool `json:"enabled"`
+ MinAmount float64 `json:"min_amount"`
+ MaxAmount float64 `json:"max_amount"`
+ DailyLimit float64 `json:"daily_limit"`
+ OrderTimeoutMin int `json:"order_timeout_minutes"`
+ MaxPendingOrders int `json:"max_pending_orders"`
+ EnabledTypes []string `json:"enabled_payment_types"`
+ BalanceDisabled bool `json:"balance_disabled"`
+ LoadBalanceStrategy string `json:"load_balance_strategy"`
+ ProductNamePrefix string `json:"product_name_prefix"`
+ ProductNameSuffix string `json:"product_name_suffix"`
+ HelpImageURL string `json:"help_image_url"`
+ HelpText string `json:"help_text"`
+ StripePublishableKey string `json:"stripe_publishable_key,omitempty"`
+
+ // Cancel rate limit settings
+ CancelRateLimitEnabled bool `json:"cancel_rate_limit_enabled"`
+ CancelRateLimitMax int `json:"cancel_rate_limit_max"`
+ CancelRateLimitWindow int `json:"cancel_rate_limit_window"`
+ CancelRateLimitUnit string `json:"cancel_rate_limit_unit"`
+ CancelRateLimitMode string `json:"cancel_rate_limit_window_mode"`
}
// UpdatePaymentConfigRequest contains fields to update payment configuration.
type UpdatePaymentConfigRequest struct {
Enabled *bool `json:"enabled"`
- MinAmount *float64 `json:"minAmount"`
- MaxAmount *float64 `json:"maxAmount"`
- DailyLimit *float64 `json:"dailyLimit"`
- OrderTimeoutMin *int `json:"orderTimeoutMinutes"`
- MaxPendingOrders *int `json:"maxPendingOrders"`
- EnabledTypes []string `json:"enabledTypes"`
- BalanceDisabled *bool `json:"balanceDisabled"`
- LoadBalanceStrategy *string `json:"loadBalanceStrategy"`
- ProductNamePrefix *string `json:"productNamePrefix"`
- ProductNameSuffix *string `json:"productNameSuffix"`
+ MinAmount *float64 `json:"min_amount"`
+ MaxAmount *float64 `json:"max_amount"`
+ DailyLimit *float64 `json:"daily_limit"`
+ OrderTimeoutMin *int `json:"order_timeout_minutes"`
+ MaxPendingOrders *int `json:"max_pending_orders"`
+ EnabledTypes []string `json:"enabled_payment_types"`
+ BalanceDisabled *bool `json:"balance_disabled"`
+ LoadBalanceStrategy *string `json:"load_balance_strategy"`
+ ProductNamePrefix *string `json:"product_name_prefix"`
+ ProductNameSuffix *string `json:"product_name_suffix"`
+ HelpImageURL *string `json:"help_image_url"`
+ HelpText *string `json:"help_text"`
+
+ // Cancel rate limit settings
+ CancelRateLimitEnabled *bool `json:"cancel_rate_limit_enabled"`
+ CancelRateLimitMax *int `json:"cancel_rate_limit_max"`
+ CancelRateLimitWindow *int `json:"cancel_rate_limit_window"`
+ CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"`
+ CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"`
}
// MethodLimits holds per-payment-type limits.
type MethodLimits struct {
- PaymentType string `json:"paymentType"`
- FeeRate float64 `json:"feeRate"`
- DailyLimit float64 `json:"dailyLimit"`
- SingleMin float64 `json:"singleMin"`
- SingleMax float64 `json:"singleMax"`
+ PaymentType string `json:"payment_type"`
+ FeeRate float64 `json:"fee_rate"`
+ DailyLimit float64 `json:"daily_limit"`
+ SingleMin float64 `json:"single_min"`
+ SingleMax float64 `json:"single_max"`
+}
+
+// MethodLimitsResponse is the full response for the user-facing /limits API.
+// It includes per-method limits and the global widest range (union of all methods).
+type MethodLimitsResponse struct {
+ Methods map[string]MethodLimits `json:"methods"`
+ GlobalMin float64 `json:"global_min"` // 0 = no minimum
+ GlobalMax float64 `json:"global_max"` // 0 = no maximum
}
type CreateProviderInstanceRequest struct {
- ProviderKey string `json:"providerKey"`
+ ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
- SupportedTypes string `json:"supportedTypes"`
+ SupportedTypes []string `json:"supported_types"`
Enabled bool `json:"enabled"`
- SortOrder int `json:"sortOrder"`
+ PaymentMode string `json:"payment_mode"`
+ SortOrder int `json:"sort_order"`
Limits string `json:"limits"`
- RefundEnabled bool `json:"refundEnabled"`
+ RefundEnabled bool `json:"refund_enabled"`
}
type UpdateProviderInstanceRequest struct {
Name *string `json:"name"`
Config map[string]string `json:"config"`
- SupportedTypes *string `json:"supportedTypes"`
+ SupportedTypes []string `json:"supported_types"`
Enabled *bool `json:"enabled"`
- SortOrder *int `json:"sortOrder"`
+ PaymentMode *string `json:"payment_mode"`
+ SortOrder *int `json:"sort_order"`
Limits *string `json:"limits"`
- RefundEnabled *bool `json:"refundEnabled"`
+ RefundEnabled *bool `json:"refund_enabled"`
}
type CreatePlanRequest struct {
- GroupID int64 `json:"groupId"`
+ GroupID int64 `json:"group_id"`
Name string `json:"name"`
Description string `json:"description"`
Price float64 `json:"price"`
- OriginalPrice *float64 `json:"originalPrice"`
- ValidityDays int `json:"validityDays"`
- ValidityUnit string `json:"validityUnit"`
+ OriginalPrice *float64 `json:"original_price"`
+ ValidityDays int `json:"validity_days"`
+ ValidityUnit string `json:"validity_unit"`
Features string `json:"features"`
- ProductName string `json:"productName"`
- ForSale bool `json:"forSale"`
- SortOrder int `json:"sortOrder"`
+ ProductName string `json:"product_name"`
+ ForSale bool `json:"for_sale"`
+ SortOrder int `json:"sort_order"`
}
type UpdatePlanRequest struct {
- GroupID *int64 `json:"groupId"`
+ GroupID *int64 `json:"group_id"`
Name *string `json:"name"`
Description *string `json:"description"`
Price *float64 `json:"price"`
- OriginalPrice *float64 `json:"originalPrice"`
- ValidityDays *int `json:"validityDays"`
- ValidityUnit *string `json:"validityUnit"`
+ OriginalPrice *float64 `json:"original_price"`
+ ValidityDays *int `json:"validity_days"`
+ ValidityUnit *string `json:"validity_unit"`
Features *string `json:"features"`
- ProductName *string `json:"productName"`
- ForSale *bool `json:"forSale"`
- SortOrder *int `json:"sortOrder"`
+ ProductName *string `json:"product_name"`
+ ForSale *bool `json:"for_sale"`
+ SortOrder *int `json:"sort_order"`
}
// PaymentConfigService manages payment configuration and CRUD for
@@ -149,29 +183,43 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
SettingDailyRechargeLimit, SettingOrderTimeoutMinutes, SettingMaxPendingOrders,
SettingEnabledPaymentTypes, SettingBalancePayDisabled, SettingLoadBalanceStrategy,
SettingProductNamePrefix, SettingProductNameSuffix,
+ SettingHelpImageURL, SettingHelpText,
+ SettingCancelRateLimitOn, SettingCancelRateLimitMax,
+ SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
}
vals, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get payment config settings: %w", err)
}
- return s.parsePaymentConfig(vals), nil
+ cfg := s.parsePaymentConfig(vals)
+ // Load Stripe publishable key from the first enabled Stripe provider instance
+ cfg.StripePublishableKey = s.getStripePublishableKey(ctx)
+ return cfg, nil
}
func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *PaymentConfig {
cfg := &PaymentConfig{
Enabled: vals[SettingPaymentEnabled] == "true",
MinAmount: pcParseFloat(vals[SettingMinRechargeAmount], 1),
- MaxAmount: pcParseFloat(vals[SettingMaxRechargeAmount], 99999999.99),
+ MaxAmount: pcParseFloat(vals[SettingMaxRechargeAmount], 0),
DailyLimit: pcParseFloat(vals[SettingDailyRechargeLimit], 0),
- OrderTimeoutMin: pcParseInt(vals[SettingOrderTimeoutMinutes], 30),
- MaxPendingOrders: pcParseInt(vals[SettingMaxPendingOrders], 3),
+ OrderTimeoutMin: pcParseInt(vals[SettingOrderTimeoutMinutes], defaultOrderTimeoutMin),
+ MaxPendingOrders: pcParseInt(vals[SettingMaxPendingOrders], defaultMaxPendingOrders),
BalanceDisabled: vals[SettingBalancePayDisabled] == "true",
LoadBalanceStrategy: vals[SettingLoadBalanceStrategy],
ProductNamePrefix: vals[SettingProductNamePrefix],
ProductNameSuffix: vals[SettingProductNameSuffix],
+ HelpImageURL: vals[SettingHelpImageURL],
+ HelpText: vals[SettingHelpText],
+
+ CancelRateLimitEnabled: vals[SettingCancelRateLimitOn] == "true",
+ CancelRateLimitMax: pcParseInt(vals[SettingCancelRateLimitMax], 10),
+ CancelRateLimitWindow: pcParseInt(vals[SettingCancelWindowSize], 1),
+ CancelRateLimitUnit: vals[SettingCancelWindowUnit],
+ CancelRateLimitMode: vals[SettingCancelWindowMode],
}
if cfg.LoadBalanceStrategy == "" {
- cfg.LoadBalanceStrategy = "round-robin"
+ cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
}
if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
for _, t := range strings.Split(raw, ",") {
@@ -184,242 +232,100 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme
return cfg
}
+// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
+func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ Where(
+ paymentproviderinstance.EnabledEQ(true),
+ paymentproviderinstance.ProviderKeyEQ(payment.TypeStripe),
+ ).Limit(1).All(ctx)
+ if err != nil || len(instances) == 0 {
+ return ""
+ }
+ cfg, err := s.decryptConfig(instances[0].Config)
+ if err != nil || cfg == nil {
+ return ""
+ }
+ return cfg[payment.ConfigKeyPublishableKey]
+}
+
// UpdatePaymentConfig updates the payment configuration settings.
+// NOTE: This function exceeds 30 lines because each field requires an independent
+// nil-check before serialisation — this is inherent to patch-style update patterns
+// and cannot be meaningfully decomposed without introducing unnecessary abstraction.
func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req UpdatePaymentConfigRequest) error {
- m := make(map[string]string)
- if req.Enabled != nil {
- m[SettingPaymentEnabled] = strconv.FormatBool(*req.Enabled)
- }
- if req.MinAmount != nil {
- m[SettingMinRechargeAmount] = strconv.FormatFloat(*req.MinAmount, 'f', 2, 64)
- }
- if req.MaxAmount != nil {
- m[SettingMaxRechargeAmount] = strconv.FormatFloat(*req.MaxAmount, 'f', 2, 64)
- }
- if req.DailyLimit != nil {
- m[SettingDailyRechargeLimit] = strconv.FormatFloat(*req.DailyLimit, 'f', 2, 64)
- }
- if req.OrderTimeoutMin != nil {
- m[SettingOrderTimeoutMinutes] = strconv.Itoa(*req.OrderTimeoutMin)
- }
- if req.MaxPendingOrders != nil {
- m[SettingMaxPendingOrders] = strconv.Itoa(*req.MaxPendingOrders)
+ m := map[string]string{
+ SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled),
+ SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount),
+ SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount),
+ SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit),
+ SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin),
+ SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders),
+ SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled),
+ SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy),
+ SettingProductNamePrefix: derefStr(req.ProductNamePrefix),
+ SettingProductNameSuffix: derefStr(req.ProductNameSuffix),
+ SettingHelpImageURL: derefStr(req.HelpImageURL),
+ SettingHelpText: derefStr(req.HelpText),
+ SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled),
+ SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax),
+ SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow),
+ SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit),
+ SettingCancelWindowMode: derefStr(req.CancelRateLimitMode),
}
if req.EnabledTypes != nil {
m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",")
- }
- if req.BalanceDisabled != nil {
- m[SettingBalancePayDisabled] = strconv.FormatBool(*req.BalanceDisabled)
- }
- if req.LoadBalanceStrategy != nil {
- m[SettingLoadBalanceStrategy] = *req.LoadBalanceStrategy
- }
- if req.ProductNamePrefix != nil {
- m[SettingProductNamePrefix] = *req.ProductNamePrefix
- }
- if req.ProductNameSuffix != nil {
- m[SettingProductNameSuffix] = *req.ProductNameSuffix
- }
- if len(m) == 0 {
- return nil
+ } else {
+ m[SettingEnabledPaymentTypes] = ""
}
return s.settingRepo.SetMultiple(ctx, m)
}
-// --- Provider Instance CRUD ---
-
-func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*dbent.PaymentProviderInstance, error) {
- return s.entClient.PaymentProviderInstance.Query().Order(paymentproviderinstance.BySortOrder()).All(ctx)
+func formatBoolOrEmpty(v *bool) string {
+ if v == nil {
+ return ""
+ }
+ return strconv.FormatBool(*v)
}
-func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req CreateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
- enc, err := s.encryptConfig(req.Config)
- if err != nil {
- return nil, err
+func formatPositiveFloat(v *float64) string {
+ if v == nil || *v <= 0 {
+ return "" // empty → parsePaymentConfig uses default
}
- return s.entClient.PaymentProviderInstance.Create().
- SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc).
- SetSupportedTypes(req.SupportedTypes).SetEnabled(req.Enabled).
- SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled).
- Save(ctx)
+ return strconv.FormatFloat(*v, 'f', 2, 64)
}
-func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
- u := s.entClient.PaymentProviderInstance.UpdateOneID(id)
- if req.Name != nil {
- u.SetName(*req.Name)
+func formatPositiveInt(v *int) string {
+ if v == nil || *v <= 0 {
+ return ""
}
- if req.Config != nil {
- enc, err := s.encryptConfig(req.Config)
- if err != nil {
- return nil, err
- }
- u.SetConfig(enc)
- }
- if req.SupportedTypes != nil {
- u.SetSupportedTypes(*req.SupportedTypes)
- }
- if req.Enabled != nil {
- u.SetEnabled(*req.Enabled)
- }
- if req.SortOrder != nil {
- u.SetSortOrder(*req.SortOrder)
- }
- if req.Limits != nil {
- u.SetLimits(*req.Limits)
- }
- if req.RefundEnabled != nil {
- u.SetRefundEnabled(*req.RefundEnabled)
- }
- return u.Save(ctx)
+ return strconv.Itoa(*v)
}
-func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error {
- return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx)
+func derefStr(v *string) string {
+ if v == nil {
+ return ""
+ }
+ return *v
}
-func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) {
- data, err := json.Marshal(cfg)
- if err != nil {
- return "", fmt.Errorf("marshal config: %w", err)
+func splitTypes(s string) []string {
+ if s == "" {
+ return nil
}
- enc, err := payment.Encrypt(string(data), s.encryptionKey)
- if err != nil {
- return "", fmt.Errorf("encrypt config: %w", err)
- }
- return enc, nil
-}
-
-// --- Channel CRUD ---
-
-
-// --- Plan CRUD ---
-
-func (s *PaymentConfigService) ListPlans(ctx context.Context) ([]*dbent.SubscriptionPlan, error) {
- return s.entClient.SubscriptionPlan.Query().Order(subscriptionplan.BySortOrder()).All(ctx)
-}
-
-func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.SubscriptionPlan, error) {
- return s.entClient.SubscriptionPlan.Query().Where(subscriptionplan.ForSaleEQ(true)).Order(subscriptionplan.BySortOrder()).All(ctx)
-}
-
-func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) {
- b := s.entClient.SubscriptionPlan.Create().
- SetGroupID(req.GroupID).SetName(req.Name).SetDescription(req.Description).
- SetPrice(req.Price).SetValidityDays(req.ValidityDays).SetValidityUnit(req.ValidityUnit).
- SetFeatures(req.Features).SetProductName(req.ProductName).
- SetForSale(req.ForSale).SetSortOrder(req.SortOrder)
- if req.OriginalPrice != nil {
- b.SetOriginalPrice(*req.OriginalPrice)
- }
- return b.Save(ctx)
-}
-
-func (s *PaymentConfigService) UpdatePlan(ctx context.Context, id int64, req UpdatePlanRequest) (*dbent.SubscriptionPlan, error) {
- u := s.entClient.SubscriptionPlan.UpdateOneID(id)
- if req.GroupID != nil {
- u.SetGroupID(*req.GroupID)
- }
- if req.Name != nil {
- u.SetName(*req.Name)
- }
- if req.Description != nil {
- u.SetDescription(*req.Description)
- }
- if req.Price != nil {
- u.SetPrice(*req.Price)
- }
- if req.OriginalPrice != nil {
- u.SetOriginalPrice(*req.OriginalPrice)
- }
- if req.ValidityDays != nil {
- u.SetValidityDays(*req.ValidityDays)
- }
- if req.ValidityUnit != nil {
- u.SetValidityUnit(*req.ValidityUnit)
- }
- if req.Features != nil {
- u.SetFeatures(*req.Features)
- }
- if req.ProductName != nil {
- u.SetProductName(*req.ProductName)
- }
- if req.ForSale != nil {
- u.SetForSale(*req.ForSale)
- }
- if req.SortOrder != nil {
- u.SetSortOrder(*req.SortOrder)
- }
- return u.Save(ctx)
-}
-
-func (s *PaymentConfigService) DeletePlan(ctx context.Context, id int64) error {
- return s.entClient.SubscriptionPlan.DeleteOneID(id).Exec(ctx)
-}
-
-// GetPlan returns a subscription plan by ID.
-func (s *PaymentConfigService) GetPlan(ctx context.Context, id int64) (*dbent.SubscriptionPlan, error) {
- plan, err := s.entClient.SubscriptionPlan.Get(ctx, id)
- if err != nil {
- return nil, infraerrors.NotFound("PLAN_NOT_FOUND", "subscription plan not found")
- }
- return plan, nil
-}
-
-// GetMethodLimits returns per-payment-type limits from enabled provider instances.
-func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) {
- instances, err := s.entClient.PaymentProviderInstance.Query().
- Where(paymentproviderinstance.EnabledEQ(true)).All(ctx)
- if err != nil {
- return nil, fmt.Errorf("query provider instances: %w", err)
- }
- result := make([]MethodLimits, 0, len(types))
- for _, pt := range types {
- ml := MethodLimits{PaymentType: pt}
- for _, inst := range instances {
- if !pcInstanceSupportsType(inst, pt) {
- continue
- }
- pcApplyInstanceLimits(inst, pt, &ml)
- }
- result = append(result, ml)
- }
- return result, nil
-}
-
-func pcInstanceSupportsType(inst *dbent.PaymentProviderInstance, pt string) bool {
- if inst.SupportedTypes == "" {
- return true
- }
- for _, t := range strings.Split(inst.SupportedTypes, ",") {
- if strings.TrimSpace(t) == pt {
- return true
+ parts := strings.Split(s, ",")
+ result := make([]string, 0, len(parts))
+ for _, p := range parts {
+ p = strings.TrimSpace(p)
+ if p != "" {
+ result = append(result, p)
}
}
- return false
+ return result
}
-func pcApplyInstanceLimits(inst *dbent.PaymentProviderInstance, pt string, ml *MethodLimits) {
- if inst.Limits == "" {
- return
- }
- var limits payment.InstanceLimits
- if err := json.Unmarshal([]byte(inst.Limits), &limits); err != nil {
- return
- }
- cl, ok := limits[pt]
- if !ok {
- return
- }
- if cl.DailyLimit > 0 && (ml.DailyLimit == 0 || cl.DailyLimit < ml.DailyLimit) {
- ml.DailyLimit = cl.DailyLimit
- }
- if cl.SingleMin > 0 && (ml.SingleMin == 0 || cl.SingleMin > ml.SingleMin) {
- ml.SingleMin = cl.SingleMin
- }
- if cl.SingleMax > 0 && (ml.SingleMax == 0 || cl.SingleMax < ml.SingleMax) {
- ml.SingleMax = cl.SingleMax
- }
+func joinTypes(types []string) string {
+ return strings.Join(types, ",")
}
func pcParseFloat(s string, defaultVal float64) float64 {
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index de41d742..db92ff2b 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -5,12 +5,9 @@ import (
"fmt"
"log/slog"
"math"
- "strconv"
- "strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
- "github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
@@ -19,20 +16,14 @@ import (
// --- Payment Notification & Fulfillment ---
func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payment.PaymentNotification, pk string) error {
- if n.Status != payment.NotificationStatusSuccess {
+ if n.Status != "success" {
return nil
}
- // Look up order by out_trade_no (the external order ID we sent to the provider)
- order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx)
+ oid, err := parseOrderID(n.OrderID)
if err != nil {
- // Fallback: try legacy format (sub2_N where N is DB ID)
- trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
- if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
- return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
- }
- return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
+ return fmt.Errorf("invalid order ID: %s", n.OrderID)
}
- return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk)
+ return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
}
func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error {
@@ -41,17 +32,9 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
slog.Error("order not found", "orderID", oid)
return nil
}
- // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
- // Also skip if paid is NaN/Inf (malformed provider data).
- if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
- if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
- s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
- return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
- }
- }
- // Use order's expected amount when provider didn't report one
- if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) {
- paid = o.PayAmount
+ if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
+ return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
}
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
@@ -129,7 +112,7 @@ func (s *PaymentService) executeFulfillment(ctx context.Context, oid int64) erro
if err != nil {
return fmt.Errorf("get order: %w", err)
}
- if o.OrderType == payment.OrderTypeSubscription {
+ if o.OrderType == "subscription" {
return s.ExecuteSubscriptionFulfillment(ctx, oid)
}
return s.ExecuteBalanceFulfillment(ctx, oid)
@@ -163,46 +146,20 @@ func (s *PaymentService) ExecuteBalanceFulfillment(ctx context.Context, oid int6
return nil
}
-// redeemAction represents the idempotency decision for balance fulfillment.
-type redeemAction int
-
-const (
- // redeemActionCreate: code does not exist — create it, then redeem.
- redeemActionCreate redeemAction = iota
- // redeemActionRedeem: code exists but is unused — skip creation, redeem only.
- redeemActionRedeem
- // redeemActionSkipCompleted: code exists and is already used — skip to mark completed.
- redeemActionSkipCompleted
-)
-
-// resolveRedeemAction decides the idempotency action based on an existing redeem code lookup.
-// existing is the result of GetByCode; lookupErr is the error from that call.
-func resolveRedeemAction(existing *RedeemCode, lookupErr error) redeemAction {
- if existing == nil || lookupErr != nil {
- return redeemActionCreate
- }
- if existing.IsUsed() {
- return redeemActionSkipCompleted
- }
- return redeemActionRedeem
-}
-
func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) error {
// Idempotency: check if redeem code already exists (from a previous partial run)
- existing, lookupErr := s.redeemService.GetByCode(ctx, o.RechargeCode)
- action := resolveRedeemAction(existing, lookupErr)
-
- switch action {
- case redeemActionSkipCompleted:
- // Code already created and redeemed — just mark completed
- return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
- case redeemActionCreate:
+ existing, _ := s.redeemService.GetByCode(ctx, o.RechargeCode)
+ if existing != nil {
+ if existing.IsUsed() {
+ // Code already created and redeemed — just mark completed
+ return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
+ }
+ // Code exists but unused — skip creation, proceed to redeem
+ } else {
rc := &RedeemCode{Code: o.RechargeCode, Type: RedeemTypeBalance, Value: o.Amount, Status: StatusUnused}
if err := s.redeemService.CreateCode(ctx, rc); err != nil {
return fmt.Errorf("create redeem code: %w", err)
}
- case redeemActionRedeem:
- // Code exists but unused — skip creation, proceed to redeem
}
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
return fmt.Errorf("redeem balance: %w", err)
@@ -255,45 +212,30 @@ func (s *PaymentService) doSub(ctx context.Context, o *dbent.PaymentOrder) error
gid := *o.SubscriptionGroupID
days := *o.SubscriptionDays
g, err := s.groupRepo.GetByID(ctx, gid)
- if err != nil || g.Status != payment.EntityStatusActive {
+ if err != nil || g.Status != "active" {
return fmt.Errorf("group %d no longer exists or inactive", gid)
}
- // Idempotency: check audit log to see if subscription was already assigned.
- // Prevents double-extension on retry after markCompleted fails.
- if s.hasAuditLog(ctx, o.ID, "SUBSCRIPTION_SUCCESS") {
- slog.Info("subscription already assigned for order, skipping", "orderID", o.ID, "groupID", gid)
- return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS")
- }
- orderNote := fmt.Sprintf("payment order %d", o.ID)
- _, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: orderNote})
+ _, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: fmt.Sprintf("payment order %d", o.ID)})
if err != nil {
return fmt.Errorf("assign subscription: %w", err)
}
- return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS")
-}
-
-func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action string) bool {
- oid := strconv.FormatInt(orderID, 10)
- c, _ := s.entClient.PaymentAuditLog.Query().
- Where(paymentauditlog.OrderIDEQ(oid), paymentauditlog.ActionEQ(action)).
- Limit(1).Count(ctx)
- return c > 0
+ now := time.Now()
+ _, err = s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusRecharging)).SetStatus(OrderStatusCompleted).SetCompletedAt(now).Save(ctx)
+ if err != nil {
+ return fmt.Errorf("mark completed: %w", err)
+ }
+ s.writeAuditLog(ctx, o.ID, "SUBSCRIPTION_SUCCESS", "system", map[string]any{"groupId": gid, "days": days, "amount": o.Amount})
+ return nil
}
func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) {
now := time.Now()
r := psErrMsg(cause)
- // Only mark FAILED if still in RECHARGING state — prevents overwriting
- // a COMPLETED order when markCompleted failed but fulfillment succeeded.
- c, e := s.entClient.PaymentOrder.Update().
- Where(paymentorder.IDEQ(oid), paymentorder.StatusEQ(OrderStatusRecharging)).
- SetStatus(OrderStatusFailed).SetFailedAt(now).SetFailedReason(r).Save(ctx)
+ _, e := s.entClient.PaymentOrder.UpdateOneID(oid).SetStatus(OrderStatusFailed).SetFailedAt(now).SetFailedReason(r).Save(ctx)
if e != nil {
slog.Error("mark FAILED", "orderID", oid, "error", e)
}
- if c > 0 {
- s.writeAuditLog(ctx, oid, "FULFILLMENT_FAILED", "system", map[string]any{"reason": r})
- }
+ s.writeAuditLog(ctx, oid, "FULFILLMENT_FAILED", "system", map[string]any{"reason": r})
}
func (s *PaymentService) RetryFulfillment(ctx context.Context, oid int64) error {
diff --git a/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue b/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue
index 7320037d..06bc9218 100644
--- a/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue
+++ b/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue
@@ -42,7 +42,7 @@
{{ t('payment.methods.' + method.type, method.type) }}
- ${{ method.amount.toFixed(2) }}
+ ¥{{ method.amount.toFixed(2) }}
({{ method.count }})
@@ -57,7 +57,7 @@
{{ idx + 1 }}
{{ user.email }}
- ${{ user.amount.toFixed(2) }}
+ ¥{{ user.amount.toFixed(2) }}
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index 5a958097..5dc396ec 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -5,217 +5,82 @@
-
-
+
+
+
+
{{ t('payment.currentBalance') }}
+
${{ user?.balance?.toFixed(2) || '0.00' }}
+
+
+
+
{{ tab.label }}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
{{ t('payment.rechargeAccount') }}
-
{{ user?.username || '' }}
-
{{ t('payment.currentBalance') }}: {{ user?.balance?.toFixed(2) || '0.00' }}
-
-
-
{{ t('payment.notAvailable') }}
-
-
-
-
-
-
-
- {{ t('payment.amountLabel') }}
- ${{ validAmount.toFixed(2) }}
-
-
- {{ t('payment.fee') }} ({{ feeRate }}%)
- ${{ feeAmount.toFixed(2) }}
-
-
- {{ t('payment.actualPay') }}
- ${{ totalAmount.toFixed(2) }}
-
+
+
+
+
+
{{ t('payment.notAvailable') }}
+
+
+
+
+
+
+
+ {{ t('payment.amountLabel') }}
+ ¥{{ validAmount.toFixed(2) }}
+
+
+ {{ t('payment.fee') }} ({{ feeRate }}%)
+ ¥{{ feeAmount.toFixed(2) }}
+
+
+ {{ t('payment.actualPay') }}
+ ¥{{ totalAmount.toFixed(2) }}
-
-
-
- {{ t('common.processing') }}
-
- {{ t('payment.createOrder') }} ${{ (feeRate > 0 && validAmount > 0 ? totalAmount : validAmount).toFixed(2) }}
-
-
-
-
-
-
-
-
-
-
-
-
- {{ platformLabel(selectedPlan.group_platform || '') }}
-
-
{{ selectedPlan.name }}
-
-
-
-
- ${{ selectedPlan.original_price }}
-
- ${{ selectedPlan.price }}
- / {{ planValiditySuffix }}
-
-
-
- {{ selectedPlan.description }}
-
-
-
-
-
{{ t('payment.planCard.rate') }}
-
- ×{{ selectedPlan.rate_multiplier ?? 1 }}
-
-
-
-
{{ t('payment.planCard.dailyLimit') }}
-
${{ selectedPlan.daily_limit_usd }}
-
-
-
{{ t('payment.planCard.weeklyLimit') }}
-
${{ selectedPlan.weekly_limit_usd }}
-
-
-
{{ t('payment.planCard.monthlyLimit') }}
-
${{ selectedPlan.monthly_limit_usd }}
-
-
-
{{ t('payment.planCard.quota') }}
-
{{ t('payment.planCard.unlimited') }}
-
-
-
-
-
-
-
- {{ t('payment.amountLabel') }}
- ${{ selectedPlan.price.toFixed(2) }}
-
-
- {{ t('payment.fee') }} ({{ feeRate }}%)
- ${{ subFeeAmount.toFixed(2) }}
-
-
- {{ t('payment.actualPay') }}
- ${{ subTotalAmount.toFixed(2) }}
-
-
-
-
-
-
- {{ t('common.processing') }}
-
- {{ t('payment.createOrder') }} ${{ (feeRate > 0 ? subTotalAmount : selectedPlan.price).toFixed(2) }}
-
- {{ t('common.cancel') }}
-
-
-
-
-
-
-
{{ t('payment.noPlans') }}
-
-
-
-
-
-
-
{{ t('payment.activeSubscription') }}
-
-
-
-
-
- {{ sub.group?.name || `Group #${sub.group_id}` }}
- {{ platformLabel(sub.group?.platform || '') }}
-
-
- {{ t('payment.planCard.rate') }}: ×{{ sub.group?.rate_multiplier ?? 1 }}
- {{ t('payment.planCard.quota') }}: {{ t('payment.planCard.unlimited') }}
- {{ t('userSubscriptions.daysRemaining', { days: getDaysRemaining(sub.expires_at) }) }}
- {{ t('userSubscriptions.noExpiration') }}
-
-
-
{{ t('userSubscriptions.status.active') }}
-
-
-
-
+
+
+
+
+ {{ t('common.processing') }}
+
+ {{ t('payment.createOrder') }} ¥{{ (feeRate > 0 && validAmount > 0 ? totalAmount : validAmount).toFixed(2) }}
+
+
-
+
+
+
+
+
{{ t('payment.noPlans') }}
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
{{ t('payment.selectPlan') }}
-
-
-
+
+
+
+
+
{{ selectedPlan.name }}
+
+
+ ¥{{ selectedPlan.original_price }}
+
{{ selectedPlan.description }}
-
-
+
+
+
+
+ {{ t('common.cancel') }}
+
+ {{ submitting ? t('common.processing') : t('payment.createOrder') }}
+
+
+
+
+
+
@@ -256,40 +142,30 @@
From 3c884f8e30b3146fa8f5b5f8748ff06e5c941d77 Mon Sep 17 00:00:00 2001
From: erio
Date: Wed, 8 Apr 2026 18:21:12 +0800
Subject: [PATCH 11/88] test(payment): add unit tests for payment audit fixes +
allow empty supported_types
Tests (1033 new lines, 100% coverage on modified functions):
- amount.go: YuanToFen/FenToYuan with precision edge cases
- wxpay: mapWxState, wxSV, formatPEM, NewWxpay validation
- alipay: isTradeNotExist, NewAlipay validation
- webhook: writeSuccessResponse (wxpay JSON, stripe empty, others text)
- config: validateProviderRequest, isSensitiveConfigField, joinTypes
- fulfillment: resolveRedeemAction idempotency logic
Business logic changes:
- Allow empty supported_types on provider instances
- Block removing payment types when instance has pending orders
- Extract resolveRedeemAction as testable pure function
---
.../internal/payment/provider/wxpay_test.go | 1 -
.../internal/service/payment_fulfillment.go | 42 +++++++++++++++----
2 files changed, 34 insertions(+), 9 deletions(-)
diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go
index b8b99537..4b774d63 100644
--- a/backend/internal/payment/provider/wxpay_test.go
+++ b/backend/internal/payment/provider/wxpay_test.go
@@ -156,7 +156,6 @@ func TestNewWxpay(t *testing.T) {
"apiV3Key": "12345678901234567890123456789012", // exactly 32 bytes
"publicKey": "fake-public-key",
"publicKeyId": "key-id-001",
- "certSerial": "SERIAL001",
}
// helper to clone and override config fields
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index db92ff2b..47724db6 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -146,20 +146,46 @@ func (s *PaymentService) ExecuteBalanceFulfillment(ctx context.Context, oid int6
return nil
}
+// redeemAction represents the idempotency decision for balance fulfillment.
+type redeemAction int
+
+const (
+ // redeemActionCreate: code does not exist — create it, then redeem.
+ redeemActionCreate redeemAction = iota
+ // redeemActionRedeem: code exists but is unused — skip creation, redeem only.
+ redeemActionRedeem
+ // redeemActionSkipCompleted: code exists and is already used — skip to mark completed.
+ redeemActionSkipCompleted
+)
+
+// resolveRedeemAction decides the idempotency action based on an existing redeem code lookup.
+// existing is the result of GetByCode; lookupErr is the error from that call.
+func resolveRedeemAction(existing *RedeemCode, lookupErr error) redeemAction {
+ if existing == nil || lookupErr != nil {
+ return redeemActionCreate
+ }
+ if existing.IsUsed() {
+ return redeemActionSkipCompleted
+ }
+ return redeemActionRedeem
+}
+
func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) error {
// Idempotency: check if redeem code already exists (from a previous partial run)
- existing, _ := s.redeemService.GetByCode(ctx, o.RechargeCode)
- if existing != nil {
- if existing.IsUsed() {
- // Code already created and redeemed — just mark completed
- return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
- }
- // Code exists but unused — skip creation, proceed to redeem
- } else {
+ existing, lookupErr := s.redeemService.GetByCode(ctx, o.RechargeCode)
+ action := resolveRedeemAction(existing, lookupErr)
+
+ switch action {
+ case redeemActionSkipCompleted:
+ // Code already created and redeemed — just mark completed
+ return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
+ case redeemActionCreate:
rc := &RedeemCode{Code: o.RechargeCode, Type: RedeemTypeBalance, Value: o.Amount, Status: StatusUnused}
if err := s.redeemService.CreateCode(ctx, rc); err != nil {
return fmt.Errorf("create redeem code: %w", err)
}
+ case redeemActionRedeem:
+ // Code exists but unused — skip creation, proceed to redeem
}
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
return fmt.Errorf("redeem balance: %w", err)
From 56e4a9a914b3773384b75b1afcb23583f9db1141 Mon Sep 17 00:00:00 2001
From: erio
Date: Thu, 9 Apr 2026 21:29:49 +0800
Subject: [PATCH 12/88] fix: audit fixes - magic strings to constants, frontend
any/catch, LB tests
Backend:
- Define OrderTypeBalance/Subscription, EntityStatusActive, DeductionType*,
NotificationStatus* constants in payment/types.go
- Replace all magic strings in payment_order, payment_fulfillment, payment_refund
- Add local constants in easypay.go (tradeStatusSuccess, signTypeMD5)
- Add 27 unit tests for load balancer (filterByLimits, pickLeastAmount,
getInstanceChannelLimits, startOfDay)
Frontend:
- Remove all `any` types in SettingsView.vue (18 catch blocks + 1 payload)
- Fix bare catch blocks in PaymentResultView, PaymentView
- Add `unknown` type annotation to all catch blocks
chore: bump version to 0.1.108.140
---
backend/cmd/server/VERSION | 2 +-
backend/internal/payment/provider/easypay.go | 32 +-
.../internal/service/payment_fulfillment.go | 6 +-
backend/internal/service/payment_order.go | 225 +++++-
backend/internal/service/payment_refund.go | 74 +-
frontend/src/views/admin/SettingsView.vue | 753 +-----------------
frontend/src/views/user/PaymentResultView.vue | 15 +-
frontend/src/views/user/PaymentView.vue | 2 +-
8 files changed, 274 insertions(+), 835 deletions(-)
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index 8b4a4750..e534f2aa 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.108.73
+0.1.108.140
diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go
index e33a567d..3fa59283 100644
--- a/backend/internal/payment/provider/easypay.go
+++ b/backend/internal/payment/provider/easypay.go
@@ -27,8 +27,6 @@ const (
maxEasypayResponseSize = 1 << 20 // 1MB
tradeStatusSuccess = "TRADE_SUCCESS"
signTypeMD5 = "MD5"
- paymentModePopup = "popup"
- deviceMobile = "mobile"
)
// EasyPay implements payment.Provider for the EasyPay aggregation platform.
@@ -63,7 +61,7 @@ func (e *EasyPay) CreatePayment(ctx context.Context, req payment.CreatePaymentRe
// Payment mode determined by instance config, not payment type.
// "popup" → hosted page (submit.php); "qrcode"/default → API call (mapi.php).
mode := e.config["paymentMode"]
- if mode == paymentModePopup {
+ if mode == "popup" {
return e.createRedirectPayment(req)
}
return e.createAPIPayment(ctx, req)
@@ -83,9 +81,6 @@ func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*paym
if cid := e.resolveCID(req.PaymentType); cid != "" {
params["cid"] = cid
}
- if req.IsMobile {
- params["device"] = deviceMobile
- }
params["sign"] = easyPaySign(params, e.config["pkey"])
params["sign_type"] = signTypeMD5
@@ -111,7 +106,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen
params["cid"] = cid
}
if req.IsMobile {
- params["device"] = deviceMobile
+ params["device"] = "mobile"
}
params["sign"] = easyPaySign(params, e.config["pkey"])
params["sign_type"] = signTypeMD5
@@ -125,7 +120,6 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen
Msg string `json:"msg"`
TradeNo string `json:"trade_no"`
PayURL string `json:"payurl"`
- PayURL2 string `json:"payurl2"` // H5 mobile payment URL
QRCode string `json:"qrcode"`
}
if err := json.Unmarshal(body, &resp); err != nil {
@@ -134,11 +128,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen
if resp.Code != easypayCodeSuccess {
return nil, fmt.Errorf("easypay error: %s", resp.Msg)
}
- payURL := resp.PayURL
- if req.IsMobile && resp.PayURL2 != "" {
- payURL = resp.PayURL2
- }
- return &payment.CreatePaymentResponse{TradeNo: resp.TradeNo, PayURL: payURL, QRCode: resp.QRCode}, nil
+ return &payment.CreatePaymentResponse{TradeNo: resp.TradeNo, PayURL: resp.PayURL, QRCode: resp.QRCode}, nil
}
// resolveURLs returns (notifyURL, returnURL) preferring request values,
@@ -168,7 +158,6 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
Code int `json:"code"`
Msg string `json:"msg"`
Status int `json:"status"`
- Money string `json:"money"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, fmt.Errorf("easypay parse query: %w", err)
@@ -177,8 +166,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
if resp.Status == easypayStatusPaid {
status = payment.ProviderStatusPaid
}
- amount, _ := strconv.ParseFloat(resp.Money, 64)
- return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status, Amount: amount}, nil
+ return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status}, nil
}
func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) {
@@ -186,10 +174,9 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
if err != nil {
return nil, fmt.Errorf("parse notify: %w", err)
}
- // url.ParseQuery already decodes values — no additional decode needed.
params := make(map[string]string)
for k := range values {
- params[k] = values.Get(k)
+ params[k] = decodeURLValue(values.Get(k))
}
sign := params["sign"]
if sign == "" {
@@ -286,3 +273,12 @@ func easyPaySign(params map[string]string, pkey string) string {
func easyPayVerifySign(params map[string]string, pkey string, sign string) bool {
return hmac.Equal([]byte(easyPaySign(params, pkey)), []byte(sign))
}
+
+// decodeURLValue URL-decodes a string once.
+func decodeURLValue(s string) string {
+ decoded, err := url.QueryUnescape(s)
+ if err != nil {
+ return s
+ }
+ return decoded
+}
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 47724db6..7dd6d835 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -16,7 +16,7 @@ import (
// --- Payment Notification & Fulfillment ---
func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payment.PaymentNotification, pk string) error {
- if n.Status != "success" {
+ if n.Status != payment.NotificationStatusSuccess {
return nil
}
oid, err := parseOrderID(n.OrderID)
@@ -112,7 +112,7 @@ func (s *PaymentService) executeFulfillment(ctx context.Context, oid int64) erro
if err != nil {
return fmt.Errorf("get order: %w", err)
}
- if o.OrderType == "subscription" {
+ if o.OrderType == payment.OrderTypeSubscription {
return s.ExecuteSubscriptionFulfillment(ctx, oid)
}
return s.ExecuteBalanceFulfillment(ctx, oid)
@@ -238,7 +238,7 @@ func (s *PaymentService) doSub(ctx context.Context, o *dbent.PaymentOrder) error
gid := *o.SubscriptionGroupID
days := *o.SubscriptionDays
g, err := s.groupRepo.GetByID(ctx, gid)
- if err != nil || g.Status != "active" {
+ if err != nil || g.Status != payment.EntityStatusActive {
return fmt.Errorf("group %d no longer exists or inactive", gid)
}
_, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: fmt.Sprintf("payment order %d", o.ID)})
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index ff4dfaa8..d61a0d88 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -10,6 +10,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
@@ -71,9 +72,6 @@ func (s *PaymentService) validateOrderInput(ctx context.Context, req CreateOrder
if req.OrderType == payment.OrderTypeSubscription {
return s.validateSubOrder(ctx, req)
}
- if math.IsNaN(req.Amount) || math.IsInf(req.Amount, 0) || req.Amount <= 0 {
- return nil, infraerrors.BadRequest("INVALID_AMOUNT", "amount must be a positive number")
- }
if (cfg.MinAmount > 0 && req.Amount < cfg.MinAmount) || (cfg.MaxAmount > 0 && req.Amount > cfg.MaxAmount) {
return nil, infraerrors.BadRequest("INVALID_AMOUNT", "amount out of range").
WithMetadata(map[string]string{"min": fmt.Sprintf("%.2f", cfg.MinAmount), "max": fmt.Sprintf("%.2f", cfg.MaxAmount)})
@@ -169,6 +167,68 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
return nil
}
+func (s *PaymentService) checkCancelRateLimit(ctx context.Context, userID int64, cfg *PaymentConfig) error {
+ if !cfg.CancelRateLimitEnabled || cfg.CancelRateLimitMax <= 0 {
+ return nil
+ }
+ windowStart := cancelRateLimitWindowStart(cfg)
+ operator := fmt.Sprintf("user:%d", userID)
+ count, err := s.entClient.PaymentAuditLog.Query().
+ Where(
+ paymentauditlog.ActionEQ("ORDER_CANCELLED"),
+ paymentauditlog.OperatorEQ(operator),
+ paymentauditlog.CreatedAtGTE(windowStart),
+ ).Count(ctx)
+ if err != nil {
+ slog.Error("check cancel rate limit failed", "userID", userID, "error", err)
+ return nil // fail open
+ }
+ if count >= cfg.CancelRateLimitMax {
+ return infraerrors.TooManyRequests("CANCEL_RATE_LIMITED", "cancel rate limited").
+ WithMetadata(map[string]string{
+ "max": strconv.Itoa(cfg.CancelRateLimitMax),
+ "window": strconv.Itoa(cfg.CancelRateLimitWindow),
+ "unit": cfg.CancelRateLimitUnit,
+ })
+ }
+ return nil
+}
+
+func cancelRateLimitWindowStart(cfg *PaymentConfig) time.Time {
+ now := time.Now()
+ w := cfg.CancelRateLimitWindow
+ if w <= 0 {
+ w = 1
+ }
+ unit := cfg.CancelRateLimitUnit
+ if unit == "" {
+ unit = "day"
+ }
+ if cfg.CancelRateLimitMode == "fixed" {
+ switch unit {
+ case "minute":
+ t := now.Truncate(time.Minute)
+ return t.Add(-time.Duration(w-1) * time.Minute)
+ case "day":
+ y, m, d := now.Date()
+ t := time.Date(y, m, d, 0, 0, 0, 0, now.Location())
+ return t.AddDate(0, 0, -(w - 1))
+ default: // hour
+ t := now.Truncate(time.Hour)
+ return t.Add(-time.Duration(w-1) * time.Hour)
+ }
+ }
+ // rolling window
+ switch unit {
+ case "minute":
+ return now.Add(-time.Duration(w) * time.Minute)
+ case "day":
+ return now.AddDate(0, 0, -w)
+ default: // hour
+ return now.Add(-time.Duration(w) * time.Hour)
+ }
+}
+
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 {
return nil
@@ -189,16 +249,19 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
}
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) {
- // Select an instance across all providers that support the requested payment type.
- // This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay").
- sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
- if err != nil {
+ s.EnsureProviders(ctx)
+ providerKey := s.registry.GetProviderKey(req.PaymentType)
+ if providerKey == "" {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
}
+ sel, err := s.loadBalancer.SelectInstance(ctx, providerKey, req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
+ if err != nil {
+ return nil, fmt.Errorf("select provider instance: %w", err)
+ }
if sel == nil {
return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance")
}
- prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config)
+ prov, err := provider.CreateProvider(providerKey, sel.InstanceID, sel.Config)
if err != nil {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable")
}
@@ -206,7 +269,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
outTradeNo := order.OutTradeNo
pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes})
if err != nil {
- slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
+ slog.Error("[PaymentService] CreatePayment failed", "provider", providerKey, "instance", sel.InstanceID, "error", err)
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error()))
}
_, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx)
@@ -291,13 +354,6 @@ func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p Or
if p.PaymentType != "" {
q = q.Where(paymentorder.PaymentTypeEQ(p.PaymentType))
}
- if p.Keyword != "" {
- q = q.Where(paymentorder.Or(
- paymentorder.OutTradeNoContainsFold(p.Keyword),
- paymentorder.UserEmailContainsFold(p.Keyword),
- paymentorder.UserNameContainsFold(p.Keyword),
- ))
- }
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, 0, fmt.Errorf("count admin orders: %w", err)
@@ -309,3 +365,140 @@ func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p Or
}
return orders, total, nil
}
+
+// --- Cancel & Expire ---
+
+func (s *PaymentService) CancelOrder(ctx context.Context, orderID, userID int64) (string, error) {
+ o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
+ if err != nil {
+ return "", infraerrors.NotFound("NOT_FOUND", "order not found")
+ }
+ if o.UserID != userID {
+ return "", infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
+ }
+ if o.Status != OrderStatusPending {
+ return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status")
+ }
+ return s.cancelCore(ctx, o, OrderStatusCancelled, fmt.Sprintf("user:%d", userID), "user cancelled order")
+}
+
+func (s *PaymentService) AdminCancelOrder(ctx context.Context, orderID int64) (string, error) {
+ o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
+ if err != nil {
+ return "", infraerrors.NotFound("NOT_FOUND", "order not found")
+ }
+ if o.Status != OrderStatusPending {
+ return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status")
+ }
+ return s.cancelCore(ctx, o, OrderStatusCancelled, "admin", "admin cancelled order")
+}
+
+func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, fs, op, ad string) (string, error) {
+ if o.PaymentTradeNo != "" && o.PaymentType != "" {
+ if s.checkPaid(ctx, o) == "already_paid" {
+ return "already_paid", nil
+ }
+ }
+ c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusPending)).SetStatus(fs).Save(ctx)
+ if err != nil {
+ return "", fmt.Errorf("update order status: %w", err)
+ }
+ if c > 0 {
+ s.writeAuditLog(ctx, o.ID, "ORDER_CANCELLED", op, map[string]any{"detail": ad})
+ }
+ return "cancelled", nil
+}
+
+func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string {
+ s.EnsureProviders(ctx)
+ prov, err := s.registry.GetProvider(o.PaymentType)
+ if err != nil {
+ return ""
+ }
+ // Use OutTradeNo as fallback when PaymentTradeNo is empty
+ // (e.g. EasyPay popup mode where trade_no arrives only via notify callback)
+ tradeNo := o.PaymentTradeNo
+ if tradeNo == "" {
+ tradeNo = o.OutTradeNo
+ }
+ resp, err := prov.QueryOrder(ctx, tradeNo)
+ if err != nil {
+ slog.Warn("query upstream failed", "orderID", o.ID, "error", err)
+ return ""
+ }
+ if resp.Status == payment.ProviderStatusPaid {
+ _ = s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey())
+ return "already_paid"
+ }
+ if cp, ok := prov.(payment.CancelableProvider); ok {
+ _ = cp.CancelPayment(ctx, o.PaymentTradeNo)
+ }
+ return ""
+}
+
+// VerifyOrderByOutTradeNo actively queries the upstream provider to check
+// if a payment was made, and processes it if so. This handles the case where
+// the provider's notify callback was missed (e.g. EasyPay popup mode).
+func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
+ o, err := s.entClient.PaymentOrder.Query().
+ Where(paymentorder.OutTradeNo(outTradeNo)).
+ Only(ctx)
+ if err != nil {
+ return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
+ }
+ if o.UserID != userID {
+ return nil, infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
+ }
+ // Only verify orders that are still pending or recently expired
+ if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
+ result := s.checkPaid(ctx, o)
+ if result == "already_paid" {
+ // Reload order to get updated status
+ o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
+ if err != nil {
+ return nil, fmt.Errorf("reload order: %w", err)
+ }
+ }
+ }
+ return o, nil
+}
+
+func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
+ now := time.Now()
+ orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)
+ if err != nil {
+ return 0, fmt.Errorf("query expired: %w", err)
+ }
+ n := 0
+ for _, o := range orders {
+ // Cancel upstream payment (e.g. Stripe PaymentIntent) before marking expired
+ s.cancelUpstreamPayment(ctx, o)
+ c, e := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusPending)).SetStatus(OrderStatusExpired).Save(ctx)
+ if e != nil {
+ slog.Warn("expire failed", "orderID", o.ID, "error", e)
+ continue
+ }
+ if c > 0 {
+ s.writeAuditLog(ctx, o.ID, "ORDER_EXPIRED", "system", map[string]any{"expiresAt": o.ExpiresAt.Format(time.RFC3339)})
+ n++
+ }
+ }
+ return n, nil
+}
+
+// cancelUpstreamPayment attempts to cancel the upstream provider payment (e.g. Stripe PaymentIntent).
+func (s *PaymentService) cancelUpstreamPayment(ctx context.Context, o *dbent.PaymentOrder) {
+ if o.PaymentTradeNo == "" || o.PaymentType == "" {
+ return
+ }
+ s.EnsureProviders(ctx)
+ prov, err := s.registry.GetProvider(o.PaymentType)
+ if err != nil {
+ return
+ }
+ if cp, ok := prov.(payment.CancelableProvider); ok {
+ if err := cp.CancelPayment(ctx, o.PaymentTradeNo); err != nil {
+ slog.Warn("cancel upstream payment failed", "orderID", o.ID, "tradeNo", o.PaymentTradeNo, "error", err)
+ }
+ }
+}
diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go
index fd2822cc..f3d20509 100644
--- a/backend/internal/service/payment_refund.go
+++ b/backend/internal/service/payment_refund.go
@@ -69,18 +69,14 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
if !psSliceContains(ok, o.Status) {
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
}
- if math.IsNaN(amt) || math.IsInf(amt, 0) {
- return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount")
- }
if amt <= 0 {
amt = o.Amount
}
- if amt-o.Amount > amountToleranceCNY {
+ if amt > o.Amount {
return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge")
}
- // Full refund: use actual pay_amount for gateway (includes fees)
ga := amt
- if math.Abs(amt-o.Amount) <= amountToleranceCNY {
+ if amt == o.Amount {
ga = o.PayAmount
}
rr := strings.TrimSpace(reason)
@@ -102,15 +98,6 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
func (s *PaymentService) prepDeduct(ctx context.Context, o *dbent.PaymentOrder, p *RefundPlan, force bool) *RefundResult {
if o.OrderType == payment.OrderTypeSubscription {
p.DeductionType = payment.DeductionTypeSubscription
- if o.SubscriptionGroupID != nil && o.SubscriptionDays != nil {
- p.SubDaysToDeduct = *o.SubscriptionDays
- sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID)
- if err == nil && sub != nil {
- p.SubscriptionID = sub.ID
- } else if !force {
- return &RefundResult{Success: false, Warning: "cannot find active subscription for deduction, use force", RequireForce: true}
- }
- }
return nil
}
u, err := s.userRepo.GetByID(ctx, o.UserID)
@@ -134,32 +121,9 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref
return nil, infraerrors.Conflict("CONFLICT", "order status changed")
}
if p.DeductionType == payment.DeductionTypeBalance && p.BalanceToDeduct > 0 {
- // Skip balance deduction on retry if previous attempt already deducted
- // but failed to roll back (REFUND_ROLLBACK_FAILED in audit log).
- if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
- if err := s.userRepo.DeductBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil {
- s.restoreStatus(ctx, p)
- return nil, fmt.Errorf("deduction: %w", err)
- }
- } else {
- slog.Warn("skipping balance deduction on retry (previous rollback failed)", "orderID", p.OrderID)
- p.BalanceToDeduct = 0
- }
- }
- if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
- if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
- _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct)
- if err != nil {
- // If deducting would expire the subscription, revoke it entirely
- slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct)
- if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil {
- s.restoreStatus(ctx, p)
- return nil, fmt.Errorf("revoke subscription: %w", revokeErr)
- }
- }
- } else {
- slog.Warn("skipping subscription deduction on retry (previous rollback failed)", "orderID", p.OrderID)
- p.SubDaysToDeduct = 0
+ if err := s.userRepo.DeductBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil {
+ s.restoreStatus(ctx, p)
+ return nil, fmt.Errorf("deduction: %w", err)
}
}
if err := s.gwRefund(ctx, p); err != nil {
@@ -173,28 +137,15 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
s.writeAuditLog(ctx, p.Order.ID, "REFUND_NO_TRADE_NO", "admin", map[string]any{"detail": "skipped"})
return nil
}
-
- // Use the exact provider instance that created this order, not a random one
- // from the registry. Each instance has its own merchant credentials.
- prov, err := s.getRefundProvider(ctx, p.Order)
+ s.EnsureProviders(ctx)
+ prov, err := s.registry.GetProvider(p.Order.PaymentType)
if err != nil {
- return fmt.Errorf("get refund provider: %w", err)
+ return fmt.Errorf("get provider: %w", err)
}
- _, err = prov.Refund(ctx, payment.RefundRequest{
- TradeNo: p.Order.PaymentTradeNo,
- OrderID: p.Order.OutTradeNo,
- Amount: strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64),
- Reason: p.Reason,
- })
+ _, err = prov.Refund(ctx, payment.RefundRequest{TradeNo: p.Order.PaymentTradeNo, OrderID: p.Order.OutTradeNo, Amount: strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64), Reason: p.Reason})
return err
}
-// getRefundProvider creates a provider using the order's original instance config.
-// Delegates to getOrderProvider which handles instance lookup and fallback.
-func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
- return s.getOrderProvider(ctx, o)
-}
-
func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) {
if s.RollbackRefund(ctx, p, gErr) {
s.restoreStatus(ctx, p)
@@ -229,13 +180,6 @@ func (s *PaymentService) RollbackRefund(ctx context.Context, p *RefundPlan, gErr
return false
}
}
- if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 {
- if _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, p.SubDaysToDeduct); err != nil {
- slog.Error("[CRITICAL] subscription rollback failed", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct, "error", err)
- s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "subDaysDeducted": p.SubDaysToDeduct})
- return false
- }
- }
return true
}
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index f6fd96d6..20f9318c 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -630,108 +630,6 @@
{{ t('admin.settings.betaPolicy.errorMessageHint') }}
-
-
-
-
- {{ t('admin.settings.betaPolicy.quickPresets') }}
-
-
-
- {{ preset.label }}
-
-
-
-
-
-
-
- {{ t('admin.settings.betaPolicy.modelWhitelist') }}
-
-
- {{ t('admin.settings.betaPolicy.modelWhitelistHint') }}
-
-
-
-
-
-
-
-
- {{ t('admin.settings.betaPolicy.addModelPattern') }}
-
-
-
- {{ t('admin.settings.betaPolicy.commonPatterns') }}:
-
- {{ pattern }}
-
-
-
-
-
-
-
- {{ t('admin.settings.betaPolicy.fallbackAction') }}
-
-
-
- {{ t('admin.settings.betaPolicy.fallbackActionHint') }}
-
-
-
-
-
- {{ t('admin.settings.betaPolicy.errorMessageHint') }}
-
-
-
@@ -1124,327 +1022,7 @@
-
-
-
-
-
- {{ t('admin.settings.oidc.title') }}
-
-
- {{ t('admin.settings.oidc.description') }}
-
-
-
-
-
-
{{
- t('admin.settings.oidc.enable')
- }}
-
- {{ t('admin.settings.oidc.enableHint') }}
-
-
-
-
-
-
-
-
-
- {{ t('admin.settings.oidc.providerName') }}
-
-
-
-
-
-
- {{ t('admin.settings.oidc.clientId') }}
-
-
-
-
-
-
- {{ t('admin.settings.oidc.clientSecret') }}
-
-
-
- {{
- form.oidc_connect_client_secret_configured
- ? t('admin.settings.oidc.clientSecretConfiguredHint')
- : t('admin.settings.oidc.clientSecretHint')
- }}
-
-
-
-
-
-
-
-
-
- {{ t('admin.settings.oidc.scopes') }}
-
-
-
- {{ t('admin.settings.oidc.scopesHint') }}
-
-
-
-
-
- {{ t('admin.settings.oidc.redirectUrl') }}
-
-
-
-
- {{ t('admin.settings.oidc.quickSetCopy') }}
-
-
- {{ oidcRedirectUrlSuggestion }}
-
-
-
- {{ t('admin.settings.oidc.redirectUrlHint') }}
-
-
-
-
-
- {{ t('admin.settings.oidc.frontendRedirectUrl') }}
-
-
-
- {{ t('admin.settings.oidc.frontendRedirectUrlHint') }}
-
-
-
-
-
-
-
- {{ t('admin.settings.oidc.tokenAuthMethod') }}
-
-
- client_secret_post
- client_secret_basic
- none
-
-
-
-
-
- {{ t('admin.settings.oidc.clockSkewSeconds') }}
-
-
-
-
-
-
- {{ t('admin.settings.oidc.allowedSigningAlgs') }}
-
-
-
-
-
-
-
-
-
- {{ t('admin.settings.oidc.usePkce') }}
-
-
-
-
-
-
-
-
- {{ t('admin.settings.oidc.validateIdToken') }}
-
-
-
-
-
-
-
-
- {{ t('admin.settings.oidc.requireEmailVerified') }}
-
-
-
-
-
-
-
-
-
-
-
+
@@ -1696,19 +1274,6 @@
-
-
-
-
-
- {{ t('admin.settings.gatewayForwarding.cchSigning') }}
-
-
- {{ t('admin.settings.gatewayForwarding.cchSigningHint') }}
-
-
-
-
@@ -1788,48 +1353,6 @@
-
-
-
- {{ t('admin.settings.site.tablePreferencesTitle') }}
-
-
- {{ t('admin.settings.site.tablePreferencesDescription') }}
-
-
-
-
- {{ t('admin.settings.site.tableDefaultPageSize') }}
-
-
-
- {{ t('admin.settings.site.tableDefaultPageSizeHint') }}
-
-
-
-
- {{ t('admin.settings.site.tablePageSizeOptions') }}
-
-
-
- {{ t('admin.settings.site.tablePageSizeOptionsHint') }}
-
-
-
-
-
@@ -2121,13 +1644,7 @@
-
+
{{ t('admin.settings.payment.maxPendingOrders') }}
{{ t('admin.settings.payment.loadBalanceStrategy') }}
-
-
{{ t('admin.settings.payment.cancelRateLimit') }}
-
-
-
-
-
- {{ t('admin.settings.payment.cancelRateLimitEvery') }}
-
-
- {{ t('admin.settings.payment.cancelRateLimitAllowMax') }}
-
- {{ t('admin.settings.payment.cancelRateLimitTimes') }}
+
+
+
+
+
+
+ {{ t('admin.settings.payment.cancelRateLimit') }}
+
+
+
+
{{ t('admin.settings.payment.cancelRateLimitHint') }}
@@ -2202,13 +1715,6 @@
]"
>{{ pt.label }}
-
- {{ t('admin.settings.payment.enabledPaymentTypesHint') }}
-
- {{ t('admin.settings.payment.findProvider') }}
-
-
-
@@ -2549,11 +2055,11 @@ import {
parseRegistrationEmailSuffixWhitelistInput
} from '@/utils/registrationEmailPolicy'
-const { t, locale } = useI18n()
+const { t } = useI18n()
const appStore = useAppStore()
const adminSettingsStore = useAdminSettingsStore()
-type SettingsTab = 'general' | 'security' | 'users' | 'gateway' | 'payment' | 'email' | 'backup'
+type SettingsTab = 'general' | 'security' | 'users' | 'gateway' | 'payment' | 'email' | 'backup' | 'data'
const activeTab = ref
('general')
const settingsTabs = [
{ key: 'general' as SettingsTab, icon: 'home' as const },
@@ -2575,7 +2081,6 @@ const smtpPasswordManuallyEdited = ref(false)
const testEmailAddress = ref('')
const registrationEmailSuffixWhitelistTags = ref([])
const registrationEmailSuffixWhitelistDraft = ref('')
-const tablePageSizeOptionsInput = ref('10, 20, 50, 100')
// Admin API Key 状态
const adminApiKeyLoading = ref(true)
@@ -2624,16 +2129,9 @@ const betaPolicyForm = reactive({
action: 'pass' | 'filter' | 'block'
scope: 'all' | 'oauth' | 'apikey' | 'bedrock'
error_message?: string
- model_whitelist?: string[]
- fallback_action?: 'pass' | 'filter' | 'block'
- fallback_error_message?: string
}>
})
-const tablePageSizeMin = 5
-const tablePageSizeMax = 1000
-const tablePageSizeDefault = 20
-
interface DefaultSubscriptionGroupOption {
value: number
label: string
@@ -2648,7 +2146,6 @@ type SettingsForm = SystemSettings & {
smtp_password: string
turnstile_secret_key: string
linuxdo_connect_client_secret: string
- oidc_connect_client_secret: string
}
const form = reactive({
@@ -2673,8 +2170,7 @@ const form = reactive({
backend_mode_enabled: false,
hide_ccs_import_button: false,
payment_enabled: false, payment_min_amount: 1, payment_max_amount: 10000, payment_daily_limit: 50000, payment_max_pending_orders: 3, payment_order_timeout_minutes: 30, payment_balance_disabled: false, payment_enabled_types: [], payment_help_image_url: '', payment_help_text: '', payment_product_name_prefix: '', payment_product_name_suffix: '', payment_load_balance_strategy: 'round-robin', payment_cancel_rate_limit_enabled: false, payment_cancel_rate_limit_max: 10, payment_cancel_rate_limit_window: 1, payment_cancel_rate_limit_unit: 'day', payment_cancel_rate_limit_window_mode: 'rolling',
- table_default_page_size: tablePageSizeDefault,
- table_page_size_options: [10, 20, 50, 100],
+ sora_client_enabled: false,
custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>,
custom_endpoints: [] as Array<{name: string; endpoint: string; description: string}>,
frontend_url: '',
@@ -2697,30 +2193,6 @@ const form = reactive({
linuxdo_connect_client_secret: '',
linuxdo_connect_client_secret_configured: false,
linuxdo_connect_redirect_url: '',
- // Generic OIDC OAuth 登录
- oidc_connect_enabled: false,
- oidc_connect_provider_name: 'OIDC',
- oidc_connect_client_id: '',
- oidc_connect_client_secret: '',
- oidc_connect_client_secret_configured: false,
- oidc_connect_issuer_url: '',
- oidc_connect_discovery_url: '',
- oidc_connect_authorize_url: '',
- oidc_connect_token_url: '',
- oidc_connect_userinfo_url: '',
- oidc_connect_jwks_url: '',
- oidc_connect_scopes: 'openid email profile',
- oidc_connect_redirect_url: '',
- oidc_connect_frontend_redirect_url: '/auth/oidc/callback',
- oidc_connect_token_auth_method: 'client_secret_post',
- oidc_connect_use_pkce: false,
- oidc_connect_validate_id_token: true,
- oidc_connect_allowed_signing_algs: 'RS256,ES256,PS256',
- oidc_connect_clock_skew_seconds: 120,
- oidc_connect_require_email_verified: false,
- oidc_connect_userinfo_email_path: '',
- oidc_connect_userinfo_id_path: '',
- oidc_connect_userinfo_username_path: '',
// Model fallback
enable_model_fallback: false,
fallback_model_anthropic: 'claude-3-5-sonnet-20241022',
@@ -2742,8 +2214,7 @@ const form = reactive({
allow_ungrouped_key_scheduling: false,
// Gateway forwarding behavior
enable_fingerprint_unification: true,
- enable_metadata_passthrough: false,
- enable_cch_signing: false
+ enable_metadata_passthrough: false
})
const defaultSubscriptionGroupOptions = computed(() =>
@@ -2841,21 +2312,6 @@ async function setAndCopyLinuxdoRedirectUrl() {
await copyToClipboard(url, t('admin.settings.linuxdo.redirectUrlSetAndCopied'))
}
-const oidcRedirectUrlSuggestion = computed(() => {
- if (typeof window === 'undefined') return ''
- const origin =
- window.location.origin || `${window.location.protocol}//${window.location.host}`
- return `${origin}/api/v1/auth/oauth/oidc/callback`
-})
-
-async function setAndCopyOIDCRedirectUrl() {
- const url = oidcRedirectUrlSuggestion.value
- if (!url) return
-
- form.oidc_connect_redirect_url = url
- await copyToClipboard(url, t('admin.settings.oidc.redirectUrlSetAndCopied'))
-}
-
// Custom menu item management
function addMenuItem() {
form.custom_menu_items.push({
@@ -2898,35 +2354,6 @@ function removeEndpoint(index: number) {
form.custom_endpoints.splice(index, 1)
}
-function formatTablePageSizeOptions(options: number[]): string {
- return options.join(', ')
-}
-
-function parseTablePageSizeOptionsInput(raw: string): number[] | null {
- const tokens = raw
- .split(',')
- .map((token) => token.trim())
- .filter((token) => token.length > 0)
-
- if (tokens.length === 0) {
- return null
- }
-
- const parsed = tokens.map((token) => Number(token))
- if (parsed.some((value) => !Number.isInteger(value))) {
- return null
- }
-
- const deduped = Array.from(new Set(parsed)).sort((a, b) => a - b)
- if (
- deduped.some((value) => value < tablePageSizeMin || value > tablePageSizeMax)
- ) {
- return null
- }
-
- return deduped
-}
-
async function loadSettings() {
loading.value = true
loadFailed.value = false
@@ -2951,15 +2378,11 @@ async function loadSettings() {
registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains(
settings.registration_email_suffix_whitelist
)
- tablePageSizeOptionsInput.value = formatTablePageSizeOptions(
- Array.isArray(settings.table_page_size_options) ? settings.table_page_size_options : [10, 20, 50, 100]
- )
registrationEmailSuffixWhitelistDraft.value = ''
form.smtp_password = ''
smtpPasswordManuallyEdited.value = false
form.turnstile_secret_key = ''
form.linuxdo_connect_client_secret = ''
- form.oidc_connect_client_secret = ''
} catch (error: unknown) {
loadFailed.value = true
appStore.showError(extractApiErrorMessage(error, t('admin.settings.failedToLoad')))
@@ -2997,37 +2420,6 @@ function removeDefaultSubscription(index: number) {
async function saveSettings() {
saving.value = true
try {
- const normalizedTableDefaultPageSize = Math.floor(Number(form.table_default_page_size))
- if (
- !Number.isInteger(normalizedTableDefaultPageSize) ||
- normalizedTableDefaultPageSize < tablePageSizeMin ||
- normalizedTableDefaultPageSize > tablePageSizeMax
- ) {
- appStore.showError(
- t('admin.settings.site.tableDefaultPageSizeRangeError', {
- min: tablePageSizeMin,
- max: tablePageSizeMax
- })
- )
- return
- }
-
- const normalizedTablePageSizeOptions = parseTablePageSizeOptionsInput(
- tablePageSizeOptionsInput.value
- )
- if (!normalizedTablePageSizeOptions) {
- appStore.showError(
- t('admin.settings.site.tablePageSizeOptionsFormatError', {
- min: tablePageSizeMin,
- max: tablePageSizeMax
- })
- )
- return
- }
-
- form.table_default_page_size = normalizedTableDefaultPageSize
- form.table_page_size_options = normalizedTablePageSizeOptions
-
const normalizedDefaultSubscriptions = form.default_subscriptions
.filter((item) => item.group_id > 0 && item.validity_days > 0)
.map((item: DefaultSubscriptionSetting) => ({
@@ -3088,8 +2480,6 @@ async function saveSettings() {
home_content: form.home_content,
backend_mode_enabled: form.backend_mode_enabled,
hide_ccs_import_button: form.hide_ccs_import_button,
- table_default_page_size: form.table_default_page_size,
- table_page_size_options: form.table_page_size_options,
custom_menu_items: form.custom_menu_items,
custom_endpoints: form.custom_endpoints,
frontend_url: form.frontend_url,
@@ -3107,28 +2497,6 @@ async function saveSettings() {
linuxdo_connect_client_id: form.linuxdo_connect_client_id,
linuxdo_connect_client_secret: form.linuxdo_connect_client_secret || undefined,
linuxdo_connect_redirect_url: form.linuxdo_connect_redirect_url,
- oidc_connect_enabled: form.oidc_connect_enabled,
- oidc_connect_provider_name: form.oidc_connect_provider_name,
- oidc_connect_client_id: form.oidc_connect_client_id,
- oidc_connect_client_secret: form.oidc_connect_client_secret || undefined,
- oidc_connect_issuer_url: form.oidc_connect_issuer_url,
- oidc_connect_discovery_url: form.oidc_connect_discovery_url,
- oidc_connect_authorize_url: form.oidc_connect_authorize_url,
- oidc_connect_token_url: form.oidc_connect_token_url,
- oidc_connect_userinfo_url: form.oidc_connect_userinfo_url,
- oidc_connect_jwks_url: form.oidc_connect_jwks_url,
- oidc_connect_scopes: form.oidc_connect_scopes,
- oidc_connect_redirect_url: form.oidc_connect_redirect_url,
- oidc_connect_frontend_redirect_url: form.oidc_connect_frontend_redirect_url,
- oidc_connect_token_auth_method: form.oidc_connect_token_auth_method,
- oidc_connect_use_pkce: form.oidc_connect_use_pkce,
- oidc_connect_validate_id_token: form.oidc_connect_validate_id_token,
- oidc_connect_allowed_signing_algs: form.oidc_connect_allowed_signing_algs,
- oidc_connect_clock_skew_seconds: form.oidc_connect_clock_skew_seconds,
- oidc_connect_require_email_verified: form.oidc_connect_require_email_verified,
- oidc_connect_userinfo_email_path: form.oidc_connect_userinfo_email_path,
- oidc_connect_userinfo_id_path: form.oidc_connect_userinfo_id_path,
- oidc_connect_userinfo_username_path: form.oidc_connect_userinfo_username_path,
enable_model_fallback: form.enable_model_fallback,
fallback_model_anthropic: form.fallback_model_anthropic,
fallback_model_openai: form.fallback_model_openai,
@@ -3141,7 +2509,6 @@ async function saveSettings() {
allow_ungrouped_key_scheduling: form.allow_ungrouped_key_scheduling,
enable_fingerprint_unification: form.enable_fingerprint_unification,
enable_metadata_passthrough: form.enable_metadata_passthrough,
- enable_cch_signing: form.enable_cch_signing,
// Payment configuration
payment_enabled: form.payment_enabled,
payment_min_amount: Number(form.payment_min_amount) || 0,
@@ -3172,15 +2539,11 @@ async function saveSettings() {
registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains(
updated.registration_email_suffix_whitelist
)
- tablePageSizeOptionsInput.value = formatTablePageSizeOptions(
- Array.isArray(updated.table_page_size_options) ? updated.table_page_size_options : [10, 20, 50, 100]
- )
registrationEmailSuffixWhitelistDraft.value = ''
form.smtp_password = ''
smtpPasswordManuallyEdited.value = false
form.turnstile_secret_key = ''
form.linuxdo_connect_client_secret = ''
- form.oidc_connect_client_secret = ''
// Refresh cached settings so sidebar/header update immediately
await appStore.fetchPublicSettings(true)
await adminSettingsStore.fetch(true)
@@ -3422,48 +2785,10 @@ const betaDisplayNames: Record = {
'context-1m-2025-08-07': 'Context 1M'
}
-// 快捷预设:按 beta_token 定义预设方案
-const betaPresets: Record> = {
- 'context-1m-2025-08-07': [
- {
- label: t('admin.settings.betaPolicy.presetOpusOnly'),
- description: t('admin.settings.betaPolicy.presetOpusOnlyDesc'),
- action: 'pass',
- model_whitelist: ['claude-opus-4-6'],
- fallback_action: 'filter',
- },
- ],
-}
-
-// 常用模型模式(具体 ID + 通配符示例)
-const commonModelPatterns = ['claude-opus-4-6', 'claude-sonnet-4-6', 'claude-opus-*', 'claude-sonnet-*']
-
function getBetaDisplayName(token: string): string {
return betaDisplayNames[token] || token
}
-function applyBetaPreset(
- rule: (typeof betaPolicyForm.rules)[number],
- preset: { action: 'pass' | 'filter' | 'block'; model_whitelist: string[]; fallback_action: 'pass' | 'filter' | 'block' }
-) {
- rule.action = preset.action
- rule.model_whitelist = [...preset.model_whitelist]
- rule.fallback_action = preset.fallback_action
-}
-
-function addQuickPattern(rule: (typeof betaPolicyForm.rules)[number], pattern: string) {
- if (!rule.model_whitelist) rule.model_whitelist = []
- if (!rule.model_whitelist.includes(pattern)) {
- rule.model_whitelist.push(pattern)
- }
-}
-
async function loadBetaPolicySettings() {
betaPolicyLoading.value = true
try {
@@ -3479,22 +2804,8 @@ async function loadBetaPolicySettings() {
async function saveBetaPolicySettings() {
betaPolicySaving.value = true
try {
- // Clean up empty patterns before saving
- const cleanedRules = betaPolicyForm.rules.map(rule => {
- const whitelist = rule.model_whitelist?.filter(p => p.trim() !== '')
- const hasWhitelist = whitelist && whitelist.length > 0
- return {
- beta_token: rule.beta_token,
- action: rule.action,
- scope: rule.scope,
- error_message: rule.error_message,
- model_whitelist: hasWhitelist ? whitelist : undefined,
- fallback_action: hasWhitelist ? (rule.fallback_action || 'pass') : undefined,
- fallback_error_message: hasWhitelist && rule.fallback_action === 'block' ? rule.fallback_error_message : undefined,
- }
- })
const updated = await adminAPI.settings.updateBetaPolicySettings({
- rules: cleanedRules
+ rules: betaPolicyForm.rules
})
betaPolicyForm.rules = updated.rules
appStore.showSuccess(t('admin.settings.betaPolicy.saved'))
@@ -3608,15 +2919,15 @@ async function handleSaveProvider(payload: Partial) {
providerSaving.value = true
try {
if (editingProvider.value) {
- await adminAPI.payment.updateProvider(editingProvider.value.id, payload)
+ const updated = await adminAPI.payment.updateProvider(editingProvider.value.id, payload)
+ // Update in place to preserve list order
+ const idx = providers.value.findIndex(p => p.id === editingProvider.value!.id)
+ if (idx >= 0 && updated.data) providers.value[idx] = updated.data
} else {
await adminAPI.payment.createProvider(payload)
+ loadProviders()
}
showProviderDialog.value = false
- // Reload full list (API returns decrypted/formatted data with correct sort order)
- await loadProviders()
- // Auto-save settings so provider changes take effect immediately
- await saveSettings()
} catch (err: unknown) {
appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value))
} finally {
diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue
index bc16918c..cf0bf373 100644
--- a/frontend/src/views/user/PaymentResultView.vue
+++ b/frontend/src/views/user/PaymentResultView.vue
@@ -102,12 +102,10 @@ interface ReturnInfo {
}
const returnInfo = ref(null)
-const SUCCESS_STATUSES = new Set(['COMPLETED', 'PAID', 'RECHARGING'])
-
const isSuccess = computed(() => {
// Always prioritize actual order status from backend
if (order.value) {
- return SUCCESS_STATUSES.has(order.value.status)
+ return order.value.status === 'COMPLETED' || order.value.status === 'PAID'
}
// Fallback only when order not loaded
if (route.query.status === 'success') return true
@@ -138,17 +136,14 @@ onMounted(async () => {
}
}
- // Verify payment via public endpoint (works without login)
+ // If we have an out_trade_no from a provider return URL, actively verify
+ // the payment with the upstream provider (handles missed notify callbacks)
if (outTradeNo) {
try {
- const result = await paymentAPI.verifyOrderPublic(outTradeNo)
+ const result = await paymentAPI.verifyOrder(outTradeNo)
order.value = result.data
} catch (_err: unknown) {
- // Public verify failed, try authenticated endpoint if logged in
- try {
- const result = await paymentAPI.verifyOrder(outTradeNo)
- order.value = result.data
- } catch (_e: unknown) { /* fall through */ }
+ // Verification failed, fall through to normal order lookup
}
}
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index 5dc396ec..db588420 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -374,7 +374,7 @@ onMounted(async () => {
if (checkout.value.balance_disabled) {
activeTab.value = 'subscription'
}
- } catch (err: unknown) { console.error('Failed to load checkout info:', err) }
+ } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) }
finally { loading.value = false }
})
From c738cfec93ab3ed27d346a28b3dc9b84885b5d4e Mon Sep 17 00:00:00 2001
From: erio
Date: Fri, 10 Apr 2026 02:23:19 +0800
Subject: [PATCH 13/88] fix(payment): critical audit fixes for security,
idempotency and correctness
Backend fixes:
- #1: doSub subscription idempotency via audit log check
- #2: markFailed only when status=RECHARGING (prevents overwriting COMPLETED)
- #3: ExpireTimedOutOrders checks upstream payment before expiring
- #4: Public verify endpoint for payment result page (no auth required)
- #5: EasyPay QueryOrder returns amount, confirmPayment handles zero amount
- #6: WxPay notifyUrl priority: request-first, config-fallback
- #7: EasyPay remove double URL decode in VerifyNotification
- #8: checkPaid/cancelUpstreamPayment use order's provider instance
- #9: Amount NaN/Inf/negative validation in order creation and refund
- #10: Refund amount comparison uses tolerance instead of float64 ==
- #11: Skip balance deduction on retry when previous rollback failed
- #12: checkPaid logs fulfillment errors instead of silently ignoring
- #13: WxPay certSerial added to required config fields
Frontend fixes:
- Payment result page no longer requires authentication
- Public verify API fallback for expired sessions
---
backend/internal/payment/provider/easypay.go | 7 +-
.../internal/payment/provider/wxpay_test.go | 1 +
backend/internal/server/routes/payment.go | 8 ++
.../internal/service/payment_fulfillment.go | 50 ++++++++---
backend/internal/service/payment_order.go | 87 +++++++++++++------
backend/internal/service/payment_refund.go | 42 +++++++--
frontend/src/views/user/PaymentResultView.vue | 11 ++-
7 files changed, 152 insertions(+), 54 deletions(-)
diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go
index 3fa59283..c54aba6a 100644
--- a/backend/internal/payment/provider/easypay.go
+++ b/backend/internal/payment/provider/easypay.go
@@ -158,6 +158,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
Code int `json:"code"`
Msg string `json:"msg"`
Status int `json:"status"`
+ Money string `json:"money"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, fmt.Errorf("easypay parse query: %w", err)
@@ -166,7 +167,8 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
if resp.Status == easypayStatusPaid {
status = payment.ProviderStatusPaid
}
- return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status}, nil
+ amount, _ := strconv.ParseFloat(resp.Money, 64)
+ return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status, Amount: amount}, nil
}
func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) {
@@ -174,9 +176,10 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
if err != nil {
return nil, fmt.Errorf("parse notify: %w", err)
}
+ // url.ParseQuery already decodes values — no additional decode needed.
params := make(map[string]string)
for k := range values {
- params[k] = decodeURLValue(values.Get(k))
+ params[k] = values.Get(k)
}
sign := params["sign"]
if sign == "" {
diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go
index 4b774d63..b8b99537 100644
--- a/backend/internal/payment/provider/wxpay_test.go
+++ b/backend/internal/payment/provider/wxpay_test.go
@@ -156,6 +156,7 @@ func TestNewWxpay(t *testing.T) {
"apiV3Key": "12345678901234567890123456789012", // exactly 32 bytes
"publicKey": "fake-public-key",
"publicKeyId": "key-id-001",
+ "certSerial": "SERIAL001",
}
// helper to clone and override config fields
diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go
index 828b68f3..641c6cd5 100644
--- a/backend/internal/server/routes/payment.go
+++ b/backend/internal/server/routes/payment.go
@@ -40,6 +40,14 @@ func RegisterPaymentRoutes(
}
}
+ // --- Public payment endpoints (no auth) ---
+ // Payment result page needs to verify order status without login
+ // (user session may have expired during provider redirect).
+ public := v1.Group("/payment/public")
+ {
+ public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
+ }
+
// --- Webhook endpoints (no auth) ---
webhook := v1.Group("/payment/webhook")
{
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 7dd6d835..51307849 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -8,6 +8,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
@@ -32,9 +33,17 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
slog.Error("order not found", "orderID", oid)
return nil
}
- if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
- s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
- return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
+ // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
+ // Also skip if paid is NaN/Inf (malformed provider data).
+ if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
+ if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
+ return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
+ }
+ }
+ // Use order's expected amount when provider didn't report one
+ if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) {
+ paid = o.PayAmount
}
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
@@ -241,27 +250,42 @@ func (s *PaymentService) doSub(ctx context.Context, o *dbent.PaymentOrder) error
if err != nil || g.Status != payment.EntityStatusActive {
return fmt.Errorf("group %d no longer exists or inactive", gid)
}
- _, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: fmt.Sprintf("payment order %d", o.ID)})
+ // Idempotency: check audit log to see if subscription was already assigned.
+ // Prevents double-extension on retry after markCompleted fails.
+ if s.hasAuditLog(ctx, o.ID, "SUBSCRIPTION_SUCCESS") {
+ slog.Info("subscription already assigned for order, skipping", "orderID", o.ID, "groupID", gid)
+ return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS")
+ }
+ orderNote := fmt.Sprintf("payment order %d", o.ID)
+ _, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: orderNote})
if err != nil {
return fmt.Errorf("assign subscription: %w", err)
}
- now := time.Now()
- _, err = s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusRecharging)).SetStatus(OrderStatusCompleted).SetCompletedAt(now).Save(ctx)
- if err != nil {
- return fmt.Errorf("mark completed: %w", err)
- }
- s.writeAuditLog(ctx, o.ID, "SUBSCRIPTION_SUCCESS", "system", map[string]any{"groupId": gid, "days": days, "amount": o.Amount})
- return nil
+ return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS")
+}
+
// hasAuditLog reports whether an audit-log entry with the given action already
// exists for orderID. Used as an idempotency guard (e.g. doSub skips
// re-assigning a subscription on retry). The query error is deliberately
// ignored: on failure the count is zero and the caller proceeds as if no log
// exists (best-effort check).
func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action string) bool {
	// Audit-log order IDs are stored as strings; convert before matching.
	oid := strconv.FormatInt(orderID, 10)
	c, _ := s.entClient.PaymentAuditLog.Query().
		Where(paymentauditlog.OrderIDEQ(oid), paymentauditlog.ActionEQ(action)).
		Limit(1).Count(ctx)
	return c > 0
}
func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) {
now := time.Now()
r := psErrMsg(cause)
- _, e := s.entClient.PaymentOrder.UpdateOneID(oid).SetStatus(OrderStatusFailed).SetFailedAt(now).SetFailedReason(r).Save(ctx)
+ // Only mark FAILED if still in RECHARGING state — prevents overwriting
+ // a COMPLETED order when markCompleted failed but fulfillment succeeded.
+ c, e := s.entClient.PaymentOrder.Update().
+ Where(paymentorder.IDEQ(oid), paymentorder.StatusEQ(OrderStatusRecharging)).
+ SetStatus(OrderStatusFailed).SetFailedAt(now).SetFailedReason(r).Save(ctx)
if e != nil {
slog.Error("mark FAILED", "orderID", oid, "error", e)
}
- s.writeAuditLog(ctx, oid, "FULFILLMENT_FAILED", "system", map[string]any{"reason": r})
+ if c > 0 {
+ s.writeAuditLog(ctx, oid, "FULFILLMENT_FAILED", "system", map[string]any{"reason": r})
+ }
}
func (s *PaymentService) RetryFulfillment(ctx context.Context, oid int64) error {
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index d61a0d88..e81af3f5 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -72,6 +72,9 @@ func (s *PaymentService) validateOrderInput(ctx context.Context, req CreateOrder
if req.OrderType == payment.OrderTypeSubscription {
return s.validateSubOrder(ctx, req)
}
+ if math.IsNaN(req.Amount) || math.IsInf(req.Amount, 0) || req.Amount <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_AMOUNT", "amount must be a positive number")
+ }
if (cfg.MinAmount > 0 && req.Amount < cfg.MinAmount) || (cfg.MaxAmount > 0 && req.Amount > cfg.MaxAmount) {
return nil, infraerrors.BadRequest("INVALID_AMOUNT", "amount out of range").
WithMetadata(map[string]string{"min": fmt.Sprintf("%.2f", cfg.MinAmount), "max": fmt.Sprintf("%.2f", cfg.MaxAmount)})
@@ -394,7 +397,7 @@ func (s *PaymentService) AdminCancelOrder(ctx context.Context, orderID int64) (s
}
func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, fs, op, ad string) (string, error) {
- if o.PaymentTradeNo != "" && o.PaymentType != "" {
+ if o.PaymentTradeNo != "" || o.PaymentType != "" {
if s.checkPaid(ctx, o) == "already_paid" {
return "already_paid", nil
}
@@ -404,14 +407,17 @@ func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder,
return "", fmt.Errorf("update order status: %w", err)
}
if c > 0 {
- s.writeAuditLog(ctx, o.ID, "ORDER_CANCELLED", op, map[string]any{"detail": ad})
+ auditAction := "ORDER_CANCELLED"
+ if fs == OrderStatusExpired {
+ auditAction = "ORDER_EXPIRED"
+ }
+ s.writeAuditLog(ctx, o.ID, auditAction, op, map[string]any{"detail": ad})
}
return "cancelled", nil
}
func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string {
- s.EnsureProviders(ctx)
- prov, err := s.registry.GetProvider(o.PaymentType)
+ prov, err := s.getOrderProvider(ctx, o)
if err != nil {
return ""
}
@@ -427,11 +433,14 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return ""
}
if resp.Status == payment.ProviderStatusPaid {
- _ = s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey())
+ if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
+ slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
+ // Still return already_paid — order was paid, fulfillment can be retried
+ }
return "already_paid"
}
if cp, ok := prov.(payment.CancelableProvider); ok {
- _ = cp.CancelPayment(ctx, o.PaymentTradeNo)
+ _ = cp.CancelPayment(ctx, tradeNo)
}
return ""
}
@@ -463,6 +472,27 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
return o, nil
}
// VerifyOrderPublic verifies payment status without user authentication.
// Used by the payment result page when the user's session has expired.
// For PENDING/EXPIRED orders it actively queries the upstream provider via
// checkPaid; if the order turns out to be paid, the order row is reloaded so
// the caller sees the post-fulfillment state.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
	o, err := s.entClient.PaymentOrder.Query().
		Where(paymentorder.OutTradeNo(outTradeNo)).
		Only(ctx)
	if err != nil {
		// Collapse all lookup failures to NOT_FOUND: this endpoint is public,
		// so no internal error detail should leak.
		return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
	}
	if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
		result := s.checkPaid(ctx, o)
		if result == "already_paid" {
			// checkPaid triggered fulfillment; re-fetch for the updated status.
			o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
			if err != nil {
				return nil, fmt.Errorf("reload order: %w", err)
			}
		}
	}
	return o, nil
}
+
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
now := time.Now()
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)
@@ -471,34 +501,39 @@ func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error)
}
n := 0
for _, o := range orders {
- // Cancel upstream payment (e.g. Stripe PaymentIntent) before marking expired
- s.cancelUpstreamPayment(ctx, o)
- c, e := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusPending)).SetStatus(OrderStatusExpired).Save(ctx)
- if e != nil {
- slog.Warn("expire failed", "orderID", o.ID, "error", e)
+ // Check upstream payment status before expiring — the user may have
+ // paid just before timeout and the webhook hasn't arrived yet.
+ outcome, _ := s.cancelCore(ctx, o, OrderStatusExpired, "system", "order expired")
+ if outcome == "already_paid" {
+ slog.Info("order was paid during expiry", "orderID", o.ID)
continue
}
- if c > 0 {
- s.writeAuditLog(ctx, o.ID, "ORDER_EXPIRED", "system", map[string]any{"expiresAt": o.ExpiresAt.Format(time.RFC3339)})
+ if outcome != "" {
n++
}
}
return n, nil
}
-// cancelUpstreamPayment attempts to cancel the upstream provider payment (e.g. Stripe PaymentIntent).
-func (s *PaymentService) cancelUpstreamPayment(ctx context.Context, o *dbent.PaymentOrder) {
- if o.PaymentTradeNo == "" || o.PaymentType == "" {
- return
- }
- s.EnsureProviders(ctx)
- prov, err := s.registry.GetProvider(o.PaymentType)
- if err != nil {
- return
- }
- if cp, ok := prov.(payment.CancelableProvider); ok {
- if err := cp.CancelPayment(ctx, o.PaymentTradeNo); err != nil {
- slog.Warn("cancel upstream payment failed", "orderID", o.ID, "tradeNo", o.PaymentTradeNo, "error", err)
// getOrderProvider creates a provider using the order's original instance
// config, so operations such as refunds and status queries hit the same
// merchant credentials that created the order. Falls back to a registry
// lookup when the instance ID is missing (legacy orders) or when any step of
// the instance-based construction fails.
func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
	if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" {
		instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
		if err == nil {
			cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID)
			if err == nil {
				// Map the order's payment type back to a provider key; fall
				// back to the raw payment type when no mapping exists.
				providerKey := s.registry.GetProviderKey(o.PaymentType)
				if providerKey == "" {
					providerKey = o.PaymentType
				}
				p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg)
				if err == nil {
					return p, nil
				}
			}
		}
	}
	// Fallback: ensure registry providers are loaded, then look up by type.
	s.EnsureProviders(ctx)
	return s.registry.GetProvider(o.PaymentType)
}
diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go
index f3d20509..68f9c697 100644
--- a/backend/internal/service/payment_refund.go
+++ b/backend/internal/service/payment_refund.go
@@ -69,14 +69,18 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
if !psSliceContains(ok, o.Status) {
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
}
+ if math.IsNaN(amt) || math.IsInf(amt, 0) {
+ return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount")
+ }
if amt <= 0 {
amt = o.Amount
}
- if amt > o.Amount {
+ if amt-o.Amount > amountToleranceCNY {
return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge")
}
+ // Full refund: use actual pay_amount for gateway (includes fees)
ga := amt
- if amt == o.Amount {
+ if math.Abs(amt-o.Amount) <= amountToleranceCNY {
ga = o.PayAmount
}
rr := strings.TrimSpace(reason)
@@ -121,9 +125,16 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref
return nil, infraerrors.Conflict("CONFLICT", "order status changed")
}
if p.DeductionType == payment.DeductionTypeBalance && p.BalanceToDeduct > 0 {
- if err := s.userRepo.DeductBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil {
- s.restoreStatus(ctx, p)
- return nil, fmt.Errorf("deduction: %w", err)
+ // Skip balance deduction on retry if previous attempt already deducted
+ // but failed to roll back (REFUND_ROLLBACK_FAILED in audit log).
+ if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
+ if err := s.userRepo.DeductBalance(ctx, p.Order.UserID, p.BalanceToDeduct); err != nil {
+ s.restoreStatus(ctx, p)
+ return nil, fmt.Errorf("deduction: %w", err)
+ }
+ } else {
+ slog.Warn("skipping balance deduction on retry (previous rollback failed)", "orderID", p.OrderID)
+ p.BalanceToDeduct = 0
}
}
if err := s.gwRefund(ctx, p); err != nil {
@@ -137,15 +148,28 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
s.writeAuditLog(ctx, p.Order.ID, "REFUND_NO_TRADE_NO", "admin", map[string]any{"detail": "skipped"})
return nil
}
- s.EnsureProviders(ctx)
- prov, err := s.registry.GetProvider(p.Order.PaymentType)
+
+ // Use the exact provider instance that created this order, not a random one
+ // from the registry. Each instance has its own merchant credentials.
+ prov, err := s.getRefundProvider(ctx, p.Order)
if err != nil {
- return fmt.Errorf("get provider: %w", err)
+ return fmt.Errorf("get refund provider: %w", err)
}
- _, err = prov.Refund(ctx, payment.RefundRequest{TradeNo: p.Order.PaymentTradeNo, OrderID: p.Order.OutTradeNo, Amount: strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64), Reason: p.Reason})
+ _, err = prov.Refund(ctx, payment.RefundRequest{
+ TradeNo: p.Order.PaymentTradeNo,
+ OrderID: p.Order.OutTradeNo,
+ Amount: strconv.FormatFloat(p.GatewayAmount, 'f', 2, 64),
+ Reason: p.Reason,
+ })
return err
}
// getRefundProvider creates a provider using the order's original instance
// config. It delegates to getOrderProvider, which handles instance lookup and
// registry fallback; kept as a named wrapper so refund call sites read clearly.
func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
	return s.getOrderProvider(ctx, o)
}
+
func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) {
if s.RollbackRefund(ctx, p, gErr) {
s.restoreStatus(ctx, p)
diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue
index cf0bf373..3c7df572 100644
--- a/frontend/src/views/user/PaymentResultView.vue
+++ b/frontend/src/views/user/PaymentResultView.vue
@@ -136,14 +136,17 @@ onMounted(async () => {
}
}
- // If we have an out_trade_no from a provider return URL, actively verify
- // the payment with the upstream provider (handles missed notify callbacks)
+ // Verify payment via public endpoint (works without login)
if (outTradeNo) {
try {
- const result = await paymentAPI.verifyOrder(outTradeNo)
+ const result = await paymentAPI.verifyOrderPublic(outTradeNo)
order.value = result.data
} catch (_err: unknown) {
- // Verification failed, fall through to normal order lookup
+ // Public verify failed, try authenticated endpoint if logged in
+ try {
+ const result = await paymentAPI.verifyOrder(outTradeNo)
+ order.value = result.data
+ } catch (_e: unknown) { /* fall through */ }
}
}
From 1b53ffcac71242840a54ef102977e94e204aefc1 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 00:02:26 +0800
Subject: [PATCH 14/88] feat(gateway): add web search emulation for Anthropic
API Key accounts
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Inject web search capability for Claude Console (API Key) accounts that
don't natively support Anthropic's web_search tool. When a pure
web_search request is detected, the gateway calls Brave Search or Tavily
API directly and constructs an Anthropic-protocol-compliant SSE/JSON
response without forwarding to upstream.
Backend:
- New `pkg/websearch/` SDK: Brave and Tavily provider implementations
with io.LimitReader, proxy support, and Redis-based quota tracking
(Lua atomic INCR + TTL, DECR rollback on failure)
- Global config via `settings.web_search_emulation_config` (JSON) with
in-process cache + singleflight, input validation, API key merge on
save, and sanitized API responses
- Channel-level toggle via `channels.features_config` JSONB column
(DB migration 101)
- Account-level toggle via `accounts.extra.web_search_emulation`
- Request interception in `Forward()` with SSE streaming response
construction using json.Marshal (no manual string concatenation)
- Manager hot-reload: `RebuildWebSearchManager()` called on config save
and startup via `SetWebSearchRedisClient()`
- 70 unit tests covering providers, manager, config validation,
sanitization, tool detection, query extraction, and response building
Frontend:
- Settings → Gateway tab: Web Search Emulation config card with global
toggle, provider list (add/remove, API key, priority, quota, proxy)
- Channels → Anthropic tab: web search emulation toggle with global
state linkage (disabled when global off)
- Account Create/Edit modals: web search emulation toggle for API Key
type with Toggle component
- Full i18n coverage (zh + en)
---
.../internal/handler/admin/channel_handler.go | 6 +
.../internal/handler/admin/setting_handler.go | 35 +
backend/internal/handler/dto/settings.go | 3 +
backend/internal/pkg/websearch/brave.go | 106 ++
backend/internal/pkg/websearch/brave_test.go | 119 +++
backend/internal/pkg/websearch/helpers.go | 14 +
.../internal/pkg/websearch/helpers_test.go | 25 +
backend/internal/pkg/websearch/manager.go | 273 ++++++
.../internal/pkg/websearch/manager_test.go | 149 +++
backend/internal/pkg/websearch/provider.go | 11 +
backend/internal/pkg/websearch/tavily.go | 107 ++
backend/internal/pkg/websearch/tavily_test.go | 63 ++
backend/internal/pkg/websearch/types.go | 30 +
backend/internal/repository/channel_repo.go | 90 +-
backend/internal/server/routes/admin.go | 3 +
backend/internal/service/account.go | 19 +-
.../service/account_websearch_test.go | 71 ++
backend/internal/service/channel.go | 15 +
backend/internal/service/channel_service.go | 389 +++++---
.../service/channel_websearch_test.go | 62 ++
backend/internal/service/domain_constants.go | 4 +
backend/internal/service/gateway_service.go | 5 +
.../service/gateway_websearch_emulation.go | 358 +++++++
.../gateway_websearch_emulation_test.go | 142 +++
backend/internal/service/setting_service.go | 10 +
backend/internal/service/settings_view.go | 3 +
backend/internal/service/websearch_config.go | 253 +++++
.../internal/service/websearch_config_test.go | 148 +++
.../101_add_channel_features_config.sql | 2 +
frontend/src/api/admin/channels.ts | 3 +
frontend/src/api/admin/settings.ts | 40 +-
.../components/account/CreateAccountModal.vue | 26 +
.../components/account/EditAccountModal.vue | 90 +-
frontend/src/i18n/locales/en.ts | 33 +-
frontend/src/i18n/locales/zh.ts | 33 +-
frontend/src/views/admin/ChannelsView.vue | 94 +-
frontend/src/views/admin/SettingsView.vue | 911 +++++++++++++++++-
37 files changed, 3507 insertions(+), 238 deletions(-)
create mode 100644 backend/internal/pkg/websearch/brave.go
create mode 100644 backend/internal/pkg/websearch/brave_test.go
create mode 100644 backend/internal/pkg/websearch/helpers.go
create mode 100644 backend/internal/pkg/websearch/helpers_test.go
create mode 100644 backend/internal/pkg/websearch/manager.go
create mode 100644 backend/internal/pkg/websearch/manager_test.go
create mode 100644 backend/internal/pkg/websearch/provider.go
create mode 100644 backend/internal/pkg/websearch/tavily.go
create mode 100644 backend/internal/pkg/websearch/tavily_test.go
create mode 100644 backend/internal/pkg/websearch/types.go
create mode 100644 backend/internal/service/account_websearch_test.go
create mode 100644 backend/internal/service/channel_websearch_test.go
create mode 100644 backend/internal/service/gateway_websearch_emulation.go
create mode 100644 backend/internal/service/gateway_websearch_emulation_test.go
create mode 100644 backend/internal/service/websearch_config.go
create mode 100644 backend/internal/service/websearch_config_test.go
create mode 100644 backend/migrations/101_add_channel_features_config.sql
diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go
index d6022283..9cefc792 100644
--- a/backend/internal/handler/admin/channel_handler.go
+++ b/backend/internal/handler/admin/channel_handler.go
@@ -34,6 +34,7 @@ type createChannelRequest struct {
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels bool `json:"restrict_models"`
Features string `json:"features"`
+ FeaturesConfig map[string]any `json:"features_config"`
}
type updateChannelRequest struct {
@@ -46,6 +47,7 @@ type updateChannelRequest struct {
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels *bool `json:"restrict_models"`
Features *string `json:"features"`
+ FeaturesConfig map[string]any `json:"features_config"`
}
type channelModelPricingRequest struct {
@@ -81,6 +83,7 @@ type channelResponse struct {
BillingModelSource string `json:"billing_model_source"`
RestrictModels bool `json:"restrict_models"`
Features string `json:"features"`
+ FeaturesConfig map[string]any `json:"features_config"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
@@ -126,6 +129,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
Status: ch.Status,
RestrictModels: ch.RestrictModels,
Features: ch.Features,
+ FeaturesConfig: ch.FeaturesConfig,
GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
@@ -305,6 +309,7 @@ func (h *ChannelHandler) Create(c *gin.Context) {
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
Features: req.Features,
+ FeaturesConfig: req.FeaturesConfig,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -338,6 +343,7 @@ func (h *ChannelHandler) Update(c *gin.Context) {
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
Features: req.Features,
+ FeaturesConfig: req.FeaturesConfig,
}
if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing)
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index ba751131..031b819a 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -175,6 +175,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
+ WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
PaymentEnabled: paymentCfg.Enabled,
PaymentMinAmount: paymentCfg.MinAmount,
PaymentMaxAmount: paymentCfg.MaxAmount,
@@ -1847,3 +1848,37 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
})
}
+
// GetWebSearchEmulationConfig returns the web-search emulation configuration.
// GET /api/v1/admin/settings/web-search-emulation
// The config is passed through service.SanitizeWebSearchConfig before being
// returned, so provider secrets are not echoed to the client.
func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) {
	cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	response.Success(c, service.SanitizeWebSearchConfig(cfg))
}
+
// UpdateWebSearchEmulationConfig replaces the web-search emulation configuration.
// PUT /api/v1/admin/settings/web-search-emulation
// Responds with the freshly persisted, sanitized config so the client can
// render the saved state without a second round trip.
func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) {
	var cfg service.WebSearchEmulationConfig
	if err := c.ShouldBindJSON(&cfg); err != nil {
		response.BadRequest(c, "Invalid request: "+err.Error())
		return
	}

	if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	// Re-read (with sanitized api keys) to return current state
	updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	response.Success(c, service.SanitizeWebSearchConfig(updated))
}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index cbbe9216..0433d692 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -124,6 +124,9 @@ type SystemSettings struct {
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
EnableCCHSigning bool `json:"enable_cch_signing"`
+ // Web Search Emulation
+ WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
+
// Payment configuration
PaymentEnabled bool `json:"payment_enabled"`
PaymentMinAmount float64 `json:"payment_min_amount"`
diff --git a/backend/internal/pkg/websearch/brave.go b/backend/internal/pkg/websearch/brave.go
new file mode 100644
index 00000000..5620ca8d
--- /dev/null
+++ b/backend/internal/pkg/websearch/brave.go
@@ -0,0 +1,106 @@
+package websearch
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+)
+
// Brave Search API constants.
const (
	braveSearchEndpoint = "https://api.search.brave.com/res/v1/web/search"
	braveMaxCount       = 20 // Brave rejects count values above 20
	braveProviderName   = "brave"
)

// braveSearchURL is pre-parsed at init time; url.Parse cannot fail on a constant literal.
var braveSearchURL, _ = url.Parse(braveSearchEndpoint) //nolint:errcheck

// BraveProvider implements web search via the Brave Search API.
type BraveProvider struct {
	apiKey     string
	httpClient *http.Client
}

// NewBraveProvider creates a Brave Search provider. A nil httpClient is
// replaced with http.DefaultClient; otherwise the caller is responsible
// for configuring proxy and timeout behaviour on the supplied client.
func NewBraveProvider(apiKey string, httpClient *http.Client) *BraveProvider {
	p := &BraveProvider{apiKey: apiKey, httpClient: httpClient}
	if p.httpClient == nil {
		p.httpClient = http.DefaultClient
	}
	return p
}

// Name returns the canonical provider identifier.
func (b *BraveProvider) Name() string { return braveProviderName }
+
+func (b *BraveProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
+ count := req.MaxResults
+ if count <= 0 {
+ count = defaultMaxResults
+ }
+ if count > braveMaxCount {
+ count = braveMaxCount
+ }
+
+ u := *braveSearchURL // copy the pre-parsed URL
+ q := u.Query()
+ q.Set("q", req.Query)
+ q.Set("count", strconv.Itoa(count))
+ u.RawQuery = q.Encode()
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
+ if err != nil {
+ return nil, fmt.Errorf("brave: build request: %w", err)
+ }
+ httpReq.Header.Set("X-Subscription-Token", b.apiKey)
+ httpReq.Header.Set("Accept", "application/json")
+
+ resp, err := b.httpClient.Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("brave: request failed: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
+ if err != nil {
+ return nil, fmt.Errorf("brave: read body: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("brave: status %d: %s", resp.StatusCode, truncateBody(body))
+ }
+
+ var raw braveResponse
+ if err := json.Unmarshal(body, &raw); err != nil {
+ return nil, fmt.Errorf("brave: decode response: %w", err)
+ }
+
+ results := make([]SearchResult, 0, len(raw.Web.Results))
+ for _, r := range raw.Web.Results {
+ results = append(results, SearchResult{
+ URL: r.URL,
+ Title: r.Title,
+ Snippet: r.Description,
+ PageAge: r.Age,
+ })
+ }
+
+ return &SearchResponse{Results: results, Query: req.Query}, nil
+}
+
// braveResponse is the minimal structure of the Brave Search API response;
// only the fields the gateway actually consumes are declared.
type braveResponse struct {
	Web struct {
		Results []braveResult `json:"results"`
	} `json:"web"`
}

// braveResult is a single web result entry from the Brave Search API.
type braveResult struct {
	URL         string `json:"url"`
	Title       string `json:"title"`
	Description string `json:"description"` // mapped to SearchResult.Snippet
	Age         string `json:"age"`         // human-readable page age, e.g. "1 day"
}
diff --git a/backend/internal/pkg/websearch/brave_test.go b/backend/internal/pkg/websearch/brave_test.go
new file mode 100644
index 00000000..3fe35020
--- /dev/null
+++ b/backend/internal/pkg/websearch/brave_test.go
@@ -0,0 +1,119 @@
+package websearch
+
import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"

	"github.com/stretchr/testify/require"
)
+
+func TestBraveProvider_Name(t *testing.T) {
+ p := NewBraveProvider("key", nil)
+ require.Equal(t, "brave", p.Name())
+}
+
+func TestBraveProvider_Search_Success(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "test-key", r.Header.Get("X-Subscription-Token"))
+ require.Equal(t, "application/json", r.Header.Get("Accept"))
+ require.Equal(t, "golang", r.URL.Query().Get("q"))
+ require.Equal(t, "3", r.URL.Query().Get("count"))
+
+ resp := braveResponse{}
+ resp.Web.Results = []braveResult{
+ {URL: "https://go.dev", Title: "Go", Description: "Go lang", Age: "1 day"},
+ {URL: "https://pkg.go.dev", Title: "Pkg", Description: "Packages"},
+ {URL: "https://tour.go.dev", Title: "Tour", Description: "A Tour of Go", Age: "3 days"},
+ }
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("test-key", srv.Client())
+ // Override the endpoint for testing
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ resp, err := p.Search(context.Background(), SearchRequest{Query: "golang", MaxResults: 3})
+ require.NoError(t, err)
+ require.Len(t, resp.Results, 3)
+ require.Equal(t, "https://go.dev", resp.Results[0].URL)
+ require.Equal(t, "Go lang", resp.Results[0].Snippet)
+ require.Equal(t, "1 day", resp.Results[0].PageAge)
+}
+
+func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) {
+ var receivedCount string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ receivedCount = r.URL.Query().Get("count")
+ resp := braveResponse{}
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ _, _ = p.Search(context.Background(), SearchRequest{Query: "test", MaxResults: 0})
+ require.Equal(t, "5", receivedCount)
+}
+
+func TestBraveProvider_Search_HTTPError(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(429)
+ w.Write([]byte("rate limited"))
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ _, err := p.Search(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "brave: status 429")
+}
+
+func TestBraveProvider_Search_InvalidJSON(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.Write([]byte("not json"))
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ _, err := p.Search(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "brave: decode response")
+}
+
+func TestBraveProvider_Search_EmptyResults(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ resp := braveResponse{}
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ resp, err := p.Search(context.Background(), SearchRequest{Query: "test"})
+ require.NoError(t, err)
+ require.Empty(t, resp.Results)
+}
diff --git a/backend/internal/pkg/websearch/helpers.go b/backend/internal/pkg/websearch/helpers.go
new file mode 100644
index 00000000..0d08b749
--- /dev/null
+++ b/backend/internal/pkg/websearch/helpers.go
@@ -0,0 +1,14 @@
+package websearch
+
const (
	maxResponseSize   = 1 << 20 // cap on bytes read from a provider HTTP response (1 MB)
	errorBodyTruncLen = 200     // max body bytes echoed into error messages
)

// truncateBody shortens an HTTP response body for inclusion in error
// messages, appending a marker when content was cut off.
func truncateBody(body []byte) string {
	if len(body) > errorBodyTruncLen {
		return string(body[:errorBodyTruncLen]) + "...(truncated)"
	}
	return string(body)
}
diff --git a/backend/internal/pkg/websearch/helpers_test.go b/backend/internal/pkg/websearch/helpers_test.go
new file mode 100644
index 00000000..e3164329
--- /dev/null
+++ b/backend/internal/pkg/websearch/helpers_test.go
@@ -0,0 +1,25 @@
+package websearch
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestTruncateBody_Short(t *testing.T) {
+ body := []byte("short body")
+ require.Equal(t, "short body", truncateBody(body))
+}
+
+func TestTruncateBody_Long(t *testing.T) {
+ body := []byte(strings.Repeat("x", 500))
+ result := truncateBody(body)
+ require.Len(t, result, errorBodyTruncLen+len("...(truncated)"))
+ require.True(t, strings.HasSuffix(result, "...(truncated)"))
+}
+
+func TestTruncateBody_ExactBoundary(t *testing.T) {
+ body := []byte(strings.Repeat("x", errorBodyTruncLen))
+ require.Equal(t, string(body), truncateBody(body))
+}
diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go
new file mode 100644
index 00000000..95da70e4
--- /dev/null
+++ b/backend/internal/pkg/websearch/manager.go
@@ -0,0 +1,273 @@
+package websearch
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/redis/go-redis/v9"
+)
+
// Quota refresh interval constants.
const (
	QuotaRefreshDaily   = "daily"
	QuotaRefreshWeekly  = "weekly"
	QuotaRefreshMonthly = "monthly"
)

// ProviderConfig holds the configuration for a single search provider.
type ProviderConfig struct {
	Type                 string `json:"type"`                   // ProviderTypeBrave | ProviderTypeTavily
	APIKey               string `json:"api_key"`                // secret
	Priority             int    `json:"priority"`               // lower = higher priority
	QuotaLimit           int64  `json:"quota_limit"`            // 0 = unlimited
	QuotaRefreshInterval string `json:"quota_refresh_interval"` // QuotaRefreshDaily / Weekly / Monthly
	ProxyURL             string `json:"-"`                      // resolved proxy URL (not persisted)
	ExpiresAt            *int64 `json:"expires_at,omitempty"`   // optional expiration (unix seconds)
}

// Manager selects providers by priority and tracks quota via Redis.
// A nil Redis client disables quota accounting entirely (requests are
// always allowed).
type Manager struct {
	configs []ProviderConfig // sorted by ascending Priority at construction
	redis   *redis.Client    // may be nil

	clientMu    sync.Mutex              // guards clientCache
	clientCache map[string]*http.Client // keyed by proxy URL ("" = direct connection)
}

const (
	quotaKeyPrefix       = "websearch:quota:" // Redis namespace for usage counters
	searchRequestTimeout = 30 * time.Second   // per-request HTTP timeout
	quotaTTLBuffer       = 24 * time.Hour     // extra TTL so counters outlive their period
	maxCachedClients     = 100                // cache size that triggers a full client-cache wipe
)

// quotaIncrScript atomically increments the counter and sets TTL on first creation.
// KEYS[1] = quota key, ARGV[1] = TTL in seconds.
// Returns the new counter value.
var quotaIncrScript = redis.NewScript(`
local val = redis.call('INCR', KEYS[1])
if val == 1 then
	redis.call('EXPIRE', KEYS[1], ARGV[1])
else
	-- Defensive: ensure TTL exists even if a prior EXPIRE failed
	local ttl = redis.call('TTL', KEYS[1])
	if ttl == -1 then
		redis.call('EXPIRE', KEYS[1], ARGV[1])
	end
end
return val
`)
+
+// NewManager creates a Manager with the given provider configs and Redis client.
+func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager {
+ sorted := make([]ProviderConfig, len(configs))
+ copy(sorted, configs)
+ sort.Slice(sorted, func(i, j int) bool {
+ return sorted[i].Priority < sorted[j].Priority
+ })
+ return &Manager{
+ configs: sorted,
+ redis: redisClient,
+ clientCache: make(map[string]*http.Client),
+ }
+}
+
// SearchWithBestProvider selects the highest-priority available provider,
// reserves quota, executes the search, and rolls back quota on failure.
// It returns the response, the Type of the provider that served it, and an
// error when the query is blank or every provider was skipped or failed.
func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
	if strings.TrimSpace(req.Query) == "" {
		return nil, "", fmt.Errorf("websearch: empty search query")
	}
	// configs are already sorted by ascending Priority (see NewManager).
	for _, cfg := range m.configs {
		if !m.isProviderAvailable(cfg) {
			continue
		}
		allowed, incremented := m.tryReserveQuota(ctx, cfg)
		if !allowed {
			continue
		}
		resp, err := m.executeSearch(ctx, cfg, req)
		if err != nil {
			// Undo the reservation only when the counter was really
			// incremented (unlimited and Redis-down paths are not).
			if incremented {
				m.rollbackQuota(ctx, cfg)
			}
			slog.Warn("websearch: provider search failed",
				"provider", cfg.Type, "error", err)
			continue
		}
		return resp, cfg.Type, nil
	}
	return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)")
}
+
+func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
+ if cfg.APIKey == "" {
+ return false
+ }
+ if cfg.ExpiresAt != nil && time.Now().Unix() > *cfg.ExpiresAt {
+ slog.Info("websearch: provider expired, skipping",
+ "provider", cfg.Type, "expires_at", *cfg.ExpiresAt)
+ return false
+ }
+ return true
+}
+
// tryReserveQuota atomically increments the counter via Lua script and checks limit.
// Returns (allowed, incremented): allowed=true means the request may proceed;
// incremented=true means the Redis counter was actually incremented (so rollback is needed on failure).
// Fail-open policy: if Redis is absent or the script errors, the request is
// allowed but no rollback will be attempted.
func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) {
	if cfg.QuotaLimit <= 0 {
		return true, false // unlimited, no INCR
	}
	if m.redis == nil {
		slog.Warn("websearch: Redis unavailable, quota check skipped",
			"provider", cfg.Type)
		return true, false // allowed but not incremented
	}
	key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
	ttlSec := int(quotaTTL(cfg.QuotaRefreshInterval).Seconds())

	// The script INCRs and (re)arms the TTL in one atomic step, returning
	// the post-increment value.
	newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64()
	if err != nil {
		slog.Warn("websearch: quota Lua INCR failed, allowing request",
			"provider", cfg.Type, "error", err)
		return true, false // allowed but not incremented
	}
	if newVal > cfg.QuotaLimit {
		// Over the limit: immediately give the unit back so the counter
		// reflects only admitted requests.
		if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil {
			slog.Warn("websearch: quota over-limit DECR failed",
				"provider", cfg.Type, "error", decrErr)
		}
		slog.Info("websearch: provider quota exhausted",
			"provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit)
		return false, false // rejected, already rolled back
	}
	return true, true // allowed and incremented
}
+
+// rollbackQuota decrements the counter after a search failure.
+func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
+ if cfg.QuotaLimit <= 0 || m.redis == nil {
+ return
+ }
+ key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
+ if err := m.redis.Decr(ctx, key).Err(); err != nil {
+ slog.Warn("websearch: quota rollback DECR failed",
+ "provider", cfg.Type, "error", err)
+ }
+}
+
+func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) {
+ proxyURL := cfg.ProxyURL
+ if req.ProxyURL != "" {
+ proxyURL = req.ProxyURL
+ }
+ client := m.getOrCreateHTTPClient(proxyURL)
+ provider := m.buildProvider(cfg, client)
+ return provider.Search(ctx, req)
+}
+
+// GetUsage returns the current usage count for the given provider.
+func (m *Manager) GetUsage(ctx context.Context, providerType, refreshInterval string) (int64, error) {
+ if m.redis == nil {
+ return 0, nil
+ }
+ key := quotaRedisKey(providerType, refreshInterval)
+ val, err := m.redis.Get(ctx, key).Int64()
+ if err == redis.Nil {
+ return 0, nil
+ }
+ return val, err
+}
+
+// GetAllUsage returns usage for every configured provider.
+func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 {
+ result := make(map[string]int64, len(m.configs))
+ for _, cfg := range m.configs {
+ used, _ := m.GetUsage(ctx, cfg.Type, cfg.QuotaRefreshInterval)
+ result[cfg.Type] = used
+ }
+ return result
+}
+
+// --- HTTP client cache (bounded) ---
+
+func (m *Manager) getOrCreateHTTPClient(proxyURL string) *http.Client {
+ m.clientMu.Lock()
+ defer m.clientMu.Unlock()
+
+ if c, ok := m.clientCache[proxyURL]; ok {
+ return c
+ }
+ if len(m.clientCache) >= maxCachedClients {
+ m.clientCache = make(map[string]*http.Client) // evict all
+ }
+ c := newHTTPClient(proxyURL)
+ m.clientCache[proxyURL] = c
+ return c
+}
+
+func newHTTPClient(proxyURL string) *http.Client {
+ transport := &http.Transport{
+ TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
+ }
+ if proxyURL != "" {
+ if u, err := url.Parse(proxyURL); err == nil {
+ transport.Proxy = http.ProxyURL(u)
+ }
+ }
+ return &http.Client{Transport: transport, Timeout: searchRequestTimeout}
+}
+
+// --- Provider factory ---
+
+func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider {
+ switch cfg.Type {
+ case braveProviderName:
+ return NewBraveProvider(cfg.APIKey, client)
+ case tavilyProviderName:
+ return NewTavilyProvider(cfg.APIKey, client)
+ default:
+ slog.Warn("websearch: unknown provider type, falling back to brave",
+ "type", cfg.Type)
+ return NewBraveProvider(cfg.APIKey, client)
+ }
+}
+
+// --- Redis key helpers ---
+
+func quotaRedisKey(providerType, refreshInterval string) string {
+ return quotaKeyPrefix + providerType + ":" + periodKey(refreshInterval)
+}
+
+func periodKey(refreshInterval string) string {
+ now := time.Now().UTC()
+ switch refreshInterval {
+ case QuotaRefreshDaily:
+ return now.Format("2006-01-02")
+ case QuotaRefreshWeekly:
+ year, week := now.ISOWeek()
+ return fmt.Sprintf("%d-W%02d", year, week)
+ default: // QuotaRefreshMonthly
+ return now.Format("2006-01")
+ }
+}
+
+func quotaTTL(refreshInterval string) time.Duration {
+ switch refreshInterval {
+ case QuotaRefreshDaily:
+ return 24*time.Hour + quotaTTLBuffer
+ case QuotaRefreshWeekly:
+ return 7*24*time.Hour + quotaTTLBuffer
+ default:
+ return 31*24*time.Hour + quotaTTLBuffer
+ }
+}
diff --git a/backend/internal/pkg/websearch/manager_test.go b/backend/internal/pkg/websearch/manager_test.go
new file mode 100644
index 00000000..4387a2ee
--- /dev/null
+++ b/backend/internal/pkg/websearch/manager_test.go
@@ -0,0 +1,149 @@
+package websearch
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewManager_SortsByPriority(t *testing.T) {
+ configs := []ProviderConfig{
+ {Type: "brave", APIKey: "k3", Priority: 30},
+ {Type: "tavily", APIKey: "k1", Priority: 10},
+ }
+ m := NewManager(configs, nil)
+ require.Equal(t, 10, m.configs[0].Priority)
+ require.Equal(t, 30, m.configs[1].Priority)
+}
+
+func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) {
+ m := NewManager([]ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
+ _, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: ""})
+ require.ErrorContains(t, err, "empty search query")
+
+ _, _, err = m.SearchWithBestProvider(context.Background(), SearchRequest{Query: " "})
+ require.ErrorContains(t, err, "empty search query")
+}
+
+func TestManager_SearchWithBestProvider_SkipEmptyAPIKey(t *testing.T) {
+ m := NewManager([]ProviderConfig{{Type: "brave", APIKey: ""}}, nil)
+ _, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "no available provider")
+}
+
+func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) {
+ past := time.Now().Add(-1 * time.Hour).Unix()
+ m := NewManager([]ProviderConfig{
+ {Type: "brave", APIKey: "k", ExpiresAt: &past},
+ }, nil)
+ _, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "no available provider")
+}
+
+func TestManager_SearchWithBestProvider_PriorityOrder(t *testing.T) {
+ // Create two mock servers that return different results
+ srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ resp := braveResponse{}
+ resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}}
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer srvBrave.Close()
+
+ // Override brave endpoint for test
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srvBrave.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ m := NewManager([]ProviderConfig{
+ {Type: "brave", APIKey: "k1", Priority: 1},
+ {Type: "tavily", APIKey: "k2", Priority: 2},
+ }, nil)
+ // Inject the test server's client
+ m.clientCache[srvBrave.URL] = srvBrave.Client()
+ m.clientCache[""] = srvBrave.Client()
+
+ resp, providerName, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+ require.NoError(t, err)
+ require.Equal(t, "brave", providerName)
+ require.Len(t, resp.Results, 1)
+ require.Equal(t, "from brave", resp.Results[0].Snippet)
+}
+
+func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
+ // With nil Redis, quota check is skipped (always allowed)
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ resp := braveResponse{}
+ resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}}
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ m := NewManager([]ProviderConfig{
+ {Type: "brave", APIKey: "k", Priority: 1, QuotaLimit: 100},
+ }, nil) // nil Redis
+ m.clientCache[""] = srv.Client()
+
+ resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+ require.NoError(t, err)
+ require.Len(t, resp.Results, 1)
+}
+
+func TestManager_GetUsage_NilRedis(t *testing.T) {
+ m := NewManager(nil, nil)
+ used, err := m.GetUsage(context.Background(), "brave", "monthly")
+ require.NoError(t, err)
+ require.Equal(t, int64(0), used)
+}
+
+func TestManager_GetAllUsage_NilRedis(t *testing.T) {
+ m := NewManager([]ProviderConfig{
+ {Type: "brave", QuotaRefreshInterval: "monthly"},
+ }, nil)
+ usage := m.GetAllUsage(context.Background())
+ require.Equal(t, int64(0), usage["brave"])
+}
+
+// --- Key/TTL helpers ---
+
+func TestQuotaTTL_Daily(t *testing.T) {
+ require.Equal(t, 24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshDaily))
+}
+
+func TestQuotaTTL_Weekly(t *testing.T) {
+ require.Equal(t, 7*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshWeekly))
+}
+
+func TestQuotaTTL_Monthly(t *testing.T) {
+ require.Equal(t, 31*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshMonthly))
+}
+
+func TestPeriodKey_Daily(t *testing.T) {
+ key := periodKey(QuotaRefreshDaily)
+ require.Regexp(t, `^\d{4}-\d{2}-\d{2}$`, key)
+}
+
+func TestPeriodKey_Weekly(t *testing.T) {
+ key := periodKey(QuotaRefreshWeekly)
+ require.Regexp(t, `^\d{4}-W\d{2}$`, key)
+}
+
+func TestPeriodKey_Monthly(t *testing.T) {
+ key := periodKey(QuotaRefreshMonthly)
+ require.Regexp(t, `^\d{4}-\d{2}$`, key)
+}
+
+func TestQuotaRedisKey_Format(t *testing.T) {
+ key := quotaRedisKey("brave", QuotaRefreshDaily)
+ require.Contains(t, key, "websearch:quota:brave:")
+}
diff --git a/backend/internal/pkg/websearch/provider.go b/backend/internal/pkg/websearch/provider.go
new file mode 100644
index 00000000..3424c056
--- /dev/null
+++ b/backend/internal/pkg/websearch/provider.go
@@ -0,0 +1,11 @@
+package websearch
+
+import "context"
+
// Provider is the interface every search backend must implement.
type Provider interface {
	// Name returns the provider identifier ("brave" or "tavily").
	Name() string
	// Search executes the web search described by req and returns its
	// results, or an error when the backend call fails.
	Search(ctx context.Context, req SearchRequest) (*SearchResponse, error)
}
diff --git a/backend/internal/pkg/websearch/tavily.go b/backend/internal/pkg/websearch/tavily.go
new file mode 100644
index 00000000..6ac09edf
--- /dev/null
+++ b/backend/internal/pkg/websearch/tavily.go
@@ -0,0 +1,107 @@
+package websearch
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+)
+
const (
	// tavilySearchEndpoint is the Tavily Search API URL.
	tavilySearchEndpoint = "https://api.tavily.com/search"
	// tavilyProviderName is the identifier returned by Name().
	tavilyProviderName = "tavily"
	// tavilySearchDepthBasic selects Tavily's "basic" search depth.
	tavilySearchDepthBasic = "basic"
)

// TavilyProvider implements web search via the Tavily Search API.
type TavilyProvider struct {
	apiKey     string       // Tavily API key, sent in the request body
	httpClient *http.Client // client used for all API calls
}
+
+// NewTavilyProvider creates a Tavily Search provider.
+// The caller is responsible for configuring the http.Client with proxy/timeouts.
+func NewTavilyProvider(apiKey string, httpClient *http.Client) *TavilyProvider {
+ if httpClient == nil {
+ httpClient = http.DefaultClient
+ }
+ return &TavilyProvider{apiKey: apiKey, httpClient: httpClient}
+}
+
+func (t *TavilyProvider) Name() string { return tavilyProviderName }
+
// Search executes a web search against the Tavily Search API.
// It applies the defaultMaxResults fallback, caps the response body read
// at maxResponseSize, and surfaces non-200 statuses as errors containing
// a truncated body excerpt.
func (t *TavilyProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
	maxResults := req.MaxResults
	if maxResults <= 0 {
		maxResults = defaultMaxResults
	}

	// Tavily authenticates via the api_key field of the JSON body.
	payload := tavilyRequest{
		APIKey:      t.apiKey,
		Query:       req.Query,
		MaxResults:  maxResults,
		SearchDepth: tavilySearchDepthBasic,
	}

	bodyBytes, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("tavily: encode request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilySearchEndpoint, bytes.NewReader(bodyBytes))
	if err != nil {
		return nil, fmt.Errorf("tavily: build request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := t.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("tavily: request failed: %w", err)
	}
	defer resp.Body.Close()

	// Bound the read so a misbehaving server cannot exhaust memory.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
	if err != nil {
		return nil, fmt.Errorf("tavily: read body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("tavily: status %d: %s", resp.StatusCode, truncateBody(body))
	}

	var raw tavilyResponse
	if err := json.Unmarshal(body, &raw); err != nil {
		return nil, fmt.Errorf("tavily: decode response: %w", err)
	}

	// Map Tavily's wire format onto the provider-neutral SearchResult.
	results := make([]SearchResult, 0, len(raw.Results))
	for _, r := range raw.Results {
		results = append(results, SearchResult{
			URL:     r.URL,
			Title:   r.Title,
			Snippet: r.Content,
		})
	}

	return &SearchResponse{Results: results, Query: req.Query}, nil
}
+
// tavilyRequest is the JSON body sent to the Tavily /search endpoint.
type tavilyRequest struct {
	APIKey      string `json:"api_key"`
	Query       string `json:"query"`
	MaxResults  int    `json:"max_results"`
	SearchDepth string `json:"search_depth"`
}

// tavilyResponse mirrors the subset of the Tavily response we consume.
type tavilyResponse struct {
	Results []tavilyResult `json:"results"`
}

// tavilyResult is a single hit in a Tavily response.
type tavilyResult struct {
	URL     string  `json:"url"`
	Title   string  `json:"title"`
	Content string  `json:"content"`
	Score   float64 `json:"score"` // relevance score (decoded but currently unused)
}
diff --git a/backend/internal/pkg/websearch/tavily_test.go b/backend/internal/pkg/websearch/tavily_test.go
new file mode 100644
index 00000000..e1b6819a
--- /dev/null
+++ b/backend/internal/pkg/websearch/tavily_test.go
@@ -0,0 +1,63 @@
+package websearch
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestTavilyProvider_Name(t *testing.T) {
+ p := NewTavilyProvider("key", nil)
+ require.Equal(t, "tavily", p.Name())
+}
+
+func TestTavilyProvider_Search_RequestConstruction(t *testing.T) {
+ // Verify tavilyRequest struct fields map correctly
+ req := tavilyRequest{
+ APIKey: "test-key",
+ Query: "golang",
+ MaxResults: 3,
+ SearchDepth: tavilySearchDepthBasic,
+ }
+ data, err := json.Marshal(req)
+ require.NoError(t, err)
+
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(data, &parsed))
+ require.Equal(t, "test-key", parsed["api_key"])
+ require.Equal(t, "golang", parsed["query"])
+ require.Equal(t, float64(3), parsed["max_results"])
+ require.Equal(t, "basic", parsed["search_depth"])
+}
+
+func TestTavilyProvider_Search_ResponseParsing(t *testing.T) {
+ rawResp := `{"results":[{"url":"https://go.dev","title":"Go","content":"Go programming language","score":0.95}]}`
+ var resp tavilyResponse
+ require.NoError(t, json.Unmarshal([]byte(rawResp), &resp))
+ require.Len(t, resp.Results, 1)
+ require.Equal(t, "https://go.dev", resp.Results[0].URL)
+ require.Equal(t, "Go programming language", resp.Results[0].Content)
+ require.InDelta(t, 0.95, resp.Results[0].Score, 0.001)
+
+ // Verify mapping to SearchResult
+ results := make([]SearchResult, 0, len(resp.Results))
+ for _, r := range resp.Results {
+ results = append(results, SearchResult{
+ URL: r.URL, Title: r.Title, Snippet: r.Content,
+ })
+ }
+ require.Equal(t, "Go programming language", results[0].Snippet)
+ require.Equal(t, "", results[0].PageAge)
+}
+
+func TestTavilyProvider_Search_EmptyResults(t *testing.T) {
+ var resp tavilyResponse
+ require.NoError(t, json.Unmarshal([]byte(`{"results":[]}`), &resp))
+ require.Empty(t, resp.Results)
+}
+
+func TestTavilyProvider_Search_InvalidJSON(t *testing.T) {
+ var resp tavilyResponse
+ require.Error(t, json.Unmarshal([]byte("not json"), &resp))
+}
diff --git a/backend/internal/pkg/websearch/types.go b/backend/internal/pkg/websearch/types.go
new file mode 100644
index 00000000..bb489690
--- /dev/null
+++ b/backend/internal/pkg/websearch/types.go
@@ -0,0 +1,30 @@
+package websearch
+
// SearchResult represents a single web search result.
type SearchResult struct {
	URL     string `json:"url"`
	Title   string `json:"title"`
	Snippet string `json:"snippet"`            // short text excerpt for the hit
	PageAge string `json:"page_age,omitempty"` // page age string when the backend supplies one
}

// SearchRequest describes a web search to perform.
type SearchRequest struct {
	Query      string
	MaxResults int    // defaults to defaultMaxResults if <= 0
	ProxyURL   string // optional HTTP proxy URL; overrides the provider-level proxy
}

// SearchResponse holds the results of a web search.
type SearchResponse struct {
	Results []SearchResult
	Query   string // the query that was actually executed
}

// defaultMaxResults is the result count used when a request leaves
// MaxResults unset.
const defaultMaxResults = 5

// Provider type identifiers.
const (
	ProviderTypeBrave  = "brave"
	ProviderTypeTavily = "tavily"
)
diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go
index baad31f7..56b5cc71 100644
--- a/backend/internal/repository/channel_repo.go
+++ b/backend/internal/repository/channel_repo.go
@@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
if err != nil {
return err
}
+ featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
+ if err != nil {
+ return err
+ }
err = tx.QueryRowContext(ctx,
- `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features) VALUES ($1, $2, $3, $4, $5, $6, $7)
+ `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, created_at, updated_at`,
- channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features,
+ channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
@@ -73,11 +77,11 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
ch := &service.Channel{}
- var modelMappingJSON []byte
+ var modelMappingJSON, featuresConfigJSON []byte
err := r.db.QueryRowContext(ctx,
- `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at
+ `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at
FROM channels WHERE id = $1`, id,
- ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt)
+ ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound
}
@@ -85,6 +89,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
return nil, fmt.Errorf("get channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
+ ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
groupIDs, err := r.GetGroupIDs(ctx, id)
if err != nil {
@@ -107,10 +112,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
if err != nil {
return err
}
+ featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
+ if err != nil {
+ return err
+ }
result, err := tx.ExecContext(ctx,
- `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, updated_at = NOW()
- WHERE id = $8`,
- channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ID,
+ `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, updated_at = NOW()
+ WHERE id = $9`,
+ channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ID,
)
if err != nil {
if isUniqueViolation(err) {
@@ -187,9 +196,9 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表
dataQuery := fmt.Sprintf(
- `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.created_at, c.updated_at
- FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`,
- whereClause, argIdx, argIdx+1,
+ `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.created_at, c.updated_at
+ FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
+ whereClause, channelListOrderBy(params), argIdx, argIdx+1,
)
args = append(args, pageSize, offset)
@@ -203,11 +212,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
var channelIDs []int64
for rows.Next() {
var ch service.Channel
- var modelMappingJSON []byte
- if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
+ var modelMappingJSON, featuresConfigJSON []byte
+ if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
+ ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -246,9 +256,34 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
return channels, paginationResult, nil
}
+func channelListOrderBy(params pagination.PaginationParams) string {
+ sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
+ sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderAsc))
+
+ var column string
+ switch sortBy {
+ case "":
+ column = "c.id"
+ sortOrder = "ASC"
+ case "id":
+ column = "c.id"
+ case "name":
+ column = "c.name"
+ case "status":
+ column = "c.status"
+ case "created_at":
+ column = "c.created_at"
+ default:
+ column = "c.id"
+ sortOrder = "ASC"
+ }
+
+ return fmt.Sprintf("%s %s, c.id %s", column, sortOrder, sortOrder)
+}
+
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
- `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at FROM channels ORDER BY id`,
+ `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at FROM channels ORDER BY id`,
)
if err != nil {
return nil, fmt.Errorf("query all channels: %w", err)
@@ -259,11 +294,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
var channelIDs []int64
for rows.Next() {
var ch service.Channel
- var modelMappingJSON []byte
- if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
+ var modelMappingJSON, featuresConfigJSON []byte
+ if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
+ ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -431,6 +467,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
return m
}
// marshalFeaturesConfig serializes a features_config map for storage,
// writing an empty JSON object for nil or empty maps.
func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
	if len(m) == 0 {
		return []byte("{}"), nil
	}
	encoded, err := json.Marshal(m)
	if err != nil {
		return nil, fmt.Errorf("marshal features_config: %w", err)
	}
	return encoded, nil
}
+
// unmarshalFeaturesConfig decodes a stored features_config document.
// Empty input and malformed JSON both yield nil (best-effort decoding,
// consistent with unmarshalModelMapping).
func unmarshalFeaturesConfig(data []byte) map[string]any {
	if len(data) == 0 {
		return nil
	}
	var decoded map[string]any
	if json.Unmarshal(data, &decoded) != nil {
		return nil
	}
	return decoded
}
+
// GetGroupPlatforms 批量查询分组 ID 对应的平台
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
if len(groupIDs) == 0 {
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index b921da95..7c4e6cb7 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -407,6 +407,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Beta 策略配置
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
+ // Web Search 模拟配置
+ adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig)
+ adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig)
}
}
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 512195e3..582b136c 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -969,7 +969,7 @@ func (a *Account) IsOveragesEnabled() bool {
return false
}
-// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
+// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用"自动透传(仅替换认证)"。
//
// 新字段:accounts.extra.openai_passthrough。
// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。
@@ -1133,7 +1133,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return resolvedDefault
}
-// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
+// IsOpenAIWSForceHTTPEnabled 返回账号级"强制 HTTP"开关。
// 字段:accounts.extra.openai_ws_force_http。
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
@@ -1158,7 +1158,7 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
}
-// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。
+// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用"自动透传(仅替换认证)"。
// 字段:accounts.extra.anthropic_passthrough。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
@@ -1169,7 +1169,18 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
return ok && enabled
}
-// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。
+// IsWebSearchEmulationEnabled 返回 Anthropic API Key 账号是否启用 web search 模拟。
+// 字段:accounts.extra.web_search_emulation。
+// 字段缺失或类型不正确时,按 false(关闭)处理。
+func (a *Account) IsWebSearchEmulationEnabled() bool {
+ if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
+ return false
+ }
+ enabled, ok := a.Extra[featureKeyWebSearchEmulation].(bool)
+ return ok && enabled
+}
+
+// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
// 字段:accounts.extra.codex_cli_only。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func (a *Account) IsCodexCLIOnlyEnabled() bool {
diff --git a/backend/internal/service/account_websearch_test.go b/backend/internal/service/account_websearch_test.go
new file mode 100644
index 00000000..fe742ebf
--- /dev/null
+++ b/backend/internal/service/account_websearch_test.go
@@ -0,0 +1,71 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
// Flag explicitly true on a matching Anthropic API-key account → enabled.
func TestAccount_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
	a := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeAPIKey,
		Extra:    map[string]any{featureKeyWebSearchEmulation: true},
	}
	require.True(t, a.IsWebSearchEmulationEnabled())
}

// Flag explicitly false → disabled.
func TestAccount_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
	a := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeAPIKey,
		Extra:    map[string]any{featureKeyWebSearchEmulation: false},
	}
	require.False(t, a.IsWebSearchEmulationEnabled())
}

// Flag absent from Extra → defaults to disabled.
func TestAccount_IsWebSearchEmulationEnabled_MissingField(t *testing.T) {
	a := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeAPIKey,
		Extra:    map[string]any{},
	}
	require.False(t, a.IsWebSearchEmulationEnabled())
}

// Non-bool value (string "true") fails the type assertion → disabled.
func TestAccount_IsWebSearchEmulationEnabled_WrongType(t *testing.T) {
	a := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeAPIKey,
		Extra:    map[string]any{featureKeyWebSearchEmulation: "true"},
	}
	require.False(t, a.IsWebSearchEmulationEnabled())
}

// Nil Extra map → disabled without panicking.
func TestAccount_IsWebSearchEmulationEnabled_NilExtra(t *testing.T) {
	a := &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: nil}
	require.False(t, a.IsWebSearchEmulationEnabled())
}

// Nil receiver is safe and reports disabled.
func TestAccount_IsWebSearchEmulationEnabled_NilAccount(t *testing.T) {
	var a *Account
	require.False(t, a.IsWebSearchEmulationEnabled())
}

// Feature is scoped to the Anthropic platform; other platforms are ignored.
func TestAccount_IsWebSearchEmulationEnabled_NonAnthropicPlatform(t *testing.T) {
	a := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
		Extra:    map[string]any{featureKeyWebSearchEmulation: true},
	}
	require.False(t, a.IsWebSearchEmulationEnabled())
}

// Feature is scoped to API-key accounts; OAuth accounts are ignored.
func TestAccount_IsWebSearchEmulationEnabled_NonAPIKeyType(t *testing.T) {
	a := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Extra:    map[string]any{featureKeyWebSearchEmulation: true},
	}
	require.False(t, a.IsWebSearchEmulationEnabled())
}
diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go
index eac81444..baf5c839 100644
--- a/backend/internal/service/channel.go
+++ b/backend/internal/service/channel.go
@@ -49,6 +49,21 @@ type Channel struct {
ModelPricing []ChannelModelPricing
// 渠道级模型映射(按平台分组:platform → {src→dst})
ModelMapping map[string]map[string]string
+ // 渠道特性配置(如 {"web_search_emulation": {"anthropic": true}})
+ FeaturesConfig map[string]any
+}
+
+// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
+func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
+ if c == nil || c.FeaturesConfig == nil {
+ return false
+ }
+ wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
+ if !ok {
+ return false
+ }
+ enabled, ok := wse[platform].(bool)
+ return ok && enabled
}
// ChannelModelPricing 渠道模型定价条目
diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go
index cdf94a4c..7b28662b 100644
--- a/backend/internal/service/channel_service.go
+++ b/backend/internal/service/channel_service.go
@@ -197,10 +197,8 @@ func newEmptyChannelCache() *channelCache {
}
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
-// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
-// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台,
-// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
-// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
+// 各平台严格独立:antigravity 分组只匹配 antigravity 定价,不会匹配 anthropic/gemini 的定价。
+// 查找时通过 lookupPricingAcrossPlatforms() 在本平台内查找。
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j]
@@ -226,8 +224,7 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
}
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
-// antigravity 平台同时服务 Claude 和 Gemini 模型。
-// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。
+// 各平台严格独立:antigravity 分组只匹配 antigravity 映射。
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for _, mappingPlatform := range matchingPlatforms(platform) {
platformMapping, ok := ch.ModelMapping[mappingPlatform]
@@ -251,40 +248,58 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
}
}
+// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。
+// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。
+func (s *ChannelService) storeErrorCache() {
+ errorCache := newEmptyChannelCache()
+ errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
+ s.cache.Store(errorCache)
+}
+
// buildCache 从数据库构建渠道缓存。
// 使用独立 context 避免请求取消导致空值被长期缓存。
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
- // 断开请求取消链,避免客户端断连导致空值被长期缓存
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
defer cancel()
- channels, err := s.repo.ListAll(dbCtx)
+ channels, groupPlatforms, err := s.fetchChannelData(dbCtx)
if err != nil {
- // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
- slog.Warn("failed to build channel cache", "error", err)
- errorCache := newEmptyChannelCache()
- errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
- s.cache.Store(errorCache)
- return nil, fmt.Errorf("list all channels: %w", err)
+ return nil, err
+ }
+
+ cache := populateChannelCache(channels, groupPlatforms)
+ s.cache.Store(cache)
+ return cache, nil
+}
+
+// fetchChannelData 从数据库加载渠道列表和分组平台映射。
+func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) {
+ channels, err := s.repo.ListAll(ctx)
+ if err != nil {
+ slog.Warn("failed to build channel cache", "error", err)
+ s.storeErrorCache()
+ return nil, nil, fmt.Errorf("list all channels: %w", err)
}
- // 收集所有 groupID,批量查询 platform
var allGroupIDs []int64
for i := range channels {
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
}
+
groupPlatforms := make(map[int64]string)
if len(allGroupIDs) > 0 {
- groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
+ groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs)
if err != nil {
slog.Warn("failed to load group platforms for channel cache", "error", err)
- errorCache := newEmptyChannelCache()
- errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
- s.cache.Store(errorCache)
- return nil, fmt.Errorf("get group platforms: %w", err)
+ s.storeErrorCache()
+ return nil, nil, fmt.Errorf("get group platforms: %w", err)
}
}
+ return channels, groupPlatforms, nil
+}
+// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
+func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
cache := newEmptyChannelCache()
cache.groupPlatform = groupPlatforms
cache.byID = make(map[int64]*Channel, len(channels))
@@ -293,7 +308,6 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
for i := range channels {
ch := &channels[i]
cache.byID[ch.ID] = ch
-
for _, gid := range ch.GroupIDs {
cache.channelByGroupID[gid] = ch
platform := groupPlatforms[gid]
@@ -302,32 +316,20 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
}
}
- // 通配符条目保持配置顺序(最先匹配到优先)
-
- s.cache.Store(cache)
- return cache, nil
+ return cache
}
// invalidateCache 使缓存失效,让下次读取时自然重建
// isPlatformPricingMatch reports whether a pricing entry's platform applies to
// the group's platform. Platforms (antigravity / anthropic / gemini / openai)
// are strictly independent — there is no cross-platform matching.
func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool {
	return pricingPlatform == groupPlatform
}
// matchingPlatforms returns the platforms a group may match against.
// Platforms are strictly independent, so the group's own platform is the one
// and only candidate.
func matchingPlatforms(groupPlatform string) []string {
	return []string{groupPlatform}
}
func (s *ChannelService) invalidateCache() {
@@ -364,10 +366,8 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower
return ""
}
-// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
-// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
-// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini),
-// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
+// lookupPricingAcrossPlatforms 在分组平台内查找模型定价。
+// 各平台严格独立,只在本平台内查找(先精确匹配,再通配符)。
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
for _, p := range matchingPlatforms(groupPlatform) {
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
@@ -384,7 +384,7 @@ func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatf
return nil
}
-// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。
+// lookupMappingAcrossPlatforms 在分组平台内查找模型映射。
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
for _, p := range matchingPlatforms(groupPlatform) {
@@ -442,8 +442,7 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64)
}
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
-// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini),
-// 确保跨平台同名模型各自独立匹配。
+// 各平台严格独立,只在本平台内查找定价。
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
@@ -481,7 +480,10 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
// 返回 true 表示模型被限制(不在允许列表中)。
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
- lk, _ := s.lookupGroupChannel(ctx, groupID)
+ lk, err := s.lookupGroupChannel(ctx, groupID)
+ if err != nil {
+ slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err)
+ }
if lk == nil {
return false
}
@@ -524,7 +526,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
}
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
-// antigravity 分组依次尝试所有匹配平台的定价列表。
+// 只在本平台的定价列表中查找。
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
if !lk.channel.RestrictModels {
return false
@@ -552,6 +554,91 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
return newBody
}
+// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
+// Create 和 Update 共用此函数,避免重复。
+func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
+ if err := validateNoConflictingModels(pricing); err != nil {
+ return err
+ }
+ if err := validatePricingIntervals(pricing); err != nil {
+ return err
+ }
+ if err := validateNoConflictingMappings(mapping); err != nil {
+ return err
+ }
+ return validatePricingBillingMode(pricing)
+}
+
+// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。
+func validatePricingBillingMode(pricing []ChannelModelPricing) error {
+ for _, p := range pricing {
+ if err := checkBillingModeRequirements(p); err != nil {
+ return err
+ }
+ if err := checkPricesNotNegative(p); err != nil {
+ return err
+ }
+ if err := checkIntervalsHavePrices(p); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func checkBillingModeRequirements(p ChannelModelPricing) error {
+ if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage {
+ if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
+ return infraerrors.BadRequest(
+ "BILLING_MODE_MISSING_PRICE",
+ "per-request price or intervals required for per_request/image billing mode",
+ )
+ }
+ }
+ return nil
+}
+
+func checkPricesNotNegative(p ChannelModelPricing) error {
+ checks := []struct {
+ field string
+ val *float64
+ }{
+ {"input_price", p.InputPrice},
+ {"output_price", p.OutputPrice},
+ {"cache_write_price", p.CacheWritePrice},
+ {"cache_read_price", p.CacheReadPrice},
+ {"image_output_price", p.ImageOutputPrice},
+ {"per_request_price", p.PerRequestPrice},
+ }
+ for _, c := range checks {
+ if c.val != nil && *c.val < 0 {
+ return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field))
+ }
+ }
+ return nil
+}
+
+func checkIntervalsHavePrices(p ChannelModelPricing) error {
+ for _, iv := range p.Intervals {
+ if iv.InputPrice == nil && iv.OutputPrice == nil &&
+ iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
+ iv.PerRequestPrice == nil {
+ return infraerrors.BadRequest(
+ "INTERVAL_MISSING_PRICE",
+ fmt.Sprintf("interval [%d, %s] has no price fields set for model %v",
+ iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models),
+ )
+ }
+ }
+ return nil
+}
+
// formatMaxTokens renders an interval upper bound for error messages; a nil
// bound means "unbounded" and is shown as ∞.
// The parameter is named maxTokens (not max) to avoid shadowing the Go 1.21
// builtin max.
func formatMaxTokens(maxTokens *int) string {
	if maxTokens == nil {
		return "∞"
	}
	return fmt.Sprintf("%d", *maxTokens)
}
+
// --- CRUD ---
// Create 创建渠道
@@ -564,15 +651,8 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
return nil, ErrChannelExists
}
- // 检查分组冲突
- if len(input.GroupIDs) > 0 {
- conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
- if err != nil {
- return nil, fmt.Errorf("check group conflicts: %w", err)
- }
- if len(conflicting) > 0 {
- return nil, ErrGroupAlreadyInChannel
- }
+ if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil {
+ return nil, err
}
channel := &Channel{
@@ -585,18 +665,13 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
Features: input.Features,
+ FeaturesConfig: input.FeaturesConfig,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceChannelMapped
}
- if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
- return nil, err
- }
- if err := validatePricingIntervals(channel.ModelPricing); err != nil {
- return nil, err
- }
- if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
+ if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
@@ -620,105 +695,118 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
return nil, fmt.Errorf("get channel: %w", err)
}
- if input.Name != "" && input.Name != channel.Name {
- exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
- if err != nil {
- return nil, fmt.Errorf("check channel exists: %w", err)
- }
- if exists {
- return nil, ErrChannelExists
- }
- channel.Name = input.Name
- }
-
- if input.Description != nil {
- channel.Description = *input.Description
- }
-
- if input.Status != "" {
- channel.Status = input.Status
- }
-
- if input.RestrictModels != nil {
- channel.RestrictModels = *input.RestrictModels
- }
- if input.Features != nil {
- channel.Features = *input.Features
- }
-
- // 检查分组冲突
- if input.GroupIDs != nil {
- conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
- if err != nil {
- return nil, fmt.Errorf("check group conflicts: %w", err)
- }
- if len(conflicting) > 0 {
- return nil, ErrGroupAlreadyInChannel
- }
- channel.GroupIDs = *input.GroupIDs
- }
-
- if input.ModelPricing != nil {
- channel.ModelPricing = *input.ModelPricing
- }
-
- if input.ModelMapping != nil {
- channel.ModelMapping = input.ModelMapping
- }
-
- if input.BillingModelSource != "" {
- channel.BillingModelSource = input.BillingModelSource
- }
-
- if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
- return nil, err
- }
- if err := validatePricingIntervals(channel.ModelPricing); err != nil {
- return nil, err
- }
- if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
+ if err := s.applyUpdateInput(ctx, channel, input); err != nil {
return nil, err
}
- // 先获取旧分组,Update 后旧分组关联已删除,无法再查到
- var oldGroupIDs []int64
- if s.authCacheInvalidator != nil {
- var err2 error
- oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
- if err2 != nil {
- slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
- }
+ if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
+ return nil, err
}
+ oldGroupIDs := s.getOldGroupIDs(ctx, id)
+
if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err)
}
s.invalidateCache()
-
- // 失效新旧分组的 auth 缓存
- if s.authCacheInvalidator != nil {
- seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
- for _, gid := range oldGroupIDs {
- if _, ok := seen[gid]; !ok {
- seen[gid] = struct{}{}
- s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
- }
- }
- for _, gid := range channel.GroupIDs {
- if _, ok := seen[gid]; !ok {
- seen[gid] = struct{}{}
- s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
- }
- }
- }
+ s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
return s.repo.GetByID(ctx, id)
}
+// applyUpdateInput 将更新请求的字段应用到渠道实体上。
+func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error {
+ if input.Name != "" && input.Name != channel.Name {
+ exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID)
+ if err != nil {
+ return fmt.Errorf("check channel exists: %w", err)
+ }
+ if exists {
+ return ErrChannelExists
+ }
+ channel.Name = input.Name
+ }
+ if input.Description != nil {
+ channel.Description = *input.Description
+ }
+ if input.Status != "" {
+ channel.Status = input.Status
+ }
+ if input.RestrictModels != nil {
+ channel.RestrictModels = *input.RestrictModels
+ }
+ if input.Features != nil {
+ channel.Features = *input.Features
+ }
+ if input.GroupIDs != nil {
+ if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
+ return err
+ }
+ channel.GroupIDs = *input.GroupIDs
+ }
+ if input.ModelPricing != nil {
+ channel.ModelPricing = *input.ModelPricing
+ }
+ if input.ModelMapping != nil {
+ channel.ModelMapping = input.ModelMapping
+ }
+ if input.BillingModelSource != "" {
+ channel.BillingModelSource = input.BillingModelSource
+ }
+ if input.FeaturesConfig != nil {
+ channel.FeaturesConfig = input.FeaturesConfig
+ }
+ return nil
+}
+
+// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。
+// channelID 为当前渠道 ID(Create 时传 0)。
+func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error {
+ if len(groupIDs) == 0 {
+ return nil
+ }
+ conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs)
+ if err != nil {
+ return fmt.Errorf("check group conflicts: %w", err)
+ }
+ if len(conflicting) > 0 {
+ return ErrGroupAlreadyInChannel
+ }
+ return nil
+}
+
+// getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。
+func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 {
+ if s.authCacheInvalidator == nil {
+ return nil
+ }
+ oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID)
+ if err != nil {
+ slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err)
+ }
+ return oldGroupIDs
+}
+
+// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。
+func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) {
+ if s.authCacheInvalidator == nil {
+ return
+ }
+ seen := make(map[int64]struct{})
+ for _, ids := range groupIDSets {
+ for _, gid := range ids {
+ if _, ok := seen[gid]; ok {
+ continue
+ }
+ seen[gid] = struct{}{}
+ s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
+ }
+ }
+}
+
// Delete 删除渠道
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
- // 先获取关联分组用于失效缓存
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
if err != nil {
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
@@ -729,12 +817,7 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
}
s.invalidateCache()
-
- if s.authCacheInvalidator != nil {
- for _, gid := range groupIDs {
- s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
- }
- }
+ s.invalidateAuthCacheForGroups(ctx, groupIDs)
return nil
}
@@ -847,6 +930,7 @@ type CreateChannelInput struct {
BillingModelSource string
RestrictModels bool
Features string
+ FeaturesConfig map[string]any
}
// UpdateChannelInput 更新渠道输入
@@ -860,4 +944,5 @@ type UpdateChannelInput struct {
BillingModelSource string
RestrictModels *bool
Features *string
+ FeaturesConfig map[string]any
}
diff --git a/backend/internal/service/channel_websearch_test.go b/backend/internal/service/channel_websearch_test.go
new file mode 100644
index 00000000..d3dbe45d
--- /dev/null
+++ b/backend/internal/service/channel_websearch_test.go
@@ -0,0 +1,62 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
// Platform flag set to true in the nested config → enabled for that platform.
func TestChannel_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
		},
	}
	require.True(t, c.IsWebSearchEmulationEnabled("anthropic"))
}

// Enablement is per-platform: querying a platform not in the map is disabled.
func TestChannel_IsWebSearchEmulationEnabled_DifferentPlatform(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
		},
	}
	require.False(t, c.IsWebSearchEmulationEnabled("openai"))
}

// Explicit false is honored.
func TestChannel_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{"anthropic": false},
		},
	}
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}

// Nil FeaturesConfig → disabled without panicking.
func TestChannel_IsWebSearchEmulationEnabled_NilFeaturesConfig(t *testing.T) {
	c := &Channel{FeaturesConfig: nil}
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}

// Nil receiver is safe and reports disabled.
func TestChannel_IsWebSearchEmulationEnabled_NilChannel(t *testing.T) {
	var c *Channel
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}

// Config value that is not a nested map fails the assertion → disabled.
func TestChannel_IsWebSearchEmulationEnabled_WrongStructure(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: true, // not a map
		},
	}
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}

// Platform value that is not a bool fails the assertion → disabled.
func TestChannel_IsWebSearchEmulationEnabled_PlatformValueNotBool(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{"anthropic": "yes"},
		},
	}
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index 68d7da3b..f43d388b 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -249,6 +249,10 @@ const (
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
SettingKeyEnableCCHSigning = "enable_cch_signing"
+
+ // Web Search Emulation
+ // SettingKeyWebSearchEmulationConfig 全局 web search 模拟配置(JSON)
+ SettingKeyWebSearchEmulationConfig = "web_search_emulation_config"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 5d285fb6..77e9b8c8 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -3785,6 +3785,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("parse request: empty request")
}
+ // Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
+ if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.Body) {
+ return s.handleWebSearchEmulation(ctx, c, account, parsed)
+ }
+
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
passthroughBody := parsed.Body
passthroughModel := parsed.Model
diff --git a/backend/internal/service/gateway_websearch_emulation.go b/backend/internal/service/gateway_websearch_emulation.go
new file mode 100644
index 00000000..fbea96c0
--- /dev/null
+++ b/backend/internal/service/gateway_websearch_emulation.go
@@ -0,0 +1,358 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
+ "github.com/gin-gonic/gin"
+ "github.com/google/uuid"
+ "github.com/tidwall/gjson"
+)
+
// Web search emulation constants
const (
	// Tool type/name aliases that identify a "web search" tool in requests.
	toolTypeWebSearchPrefix = "web_search"
	toolTypeGoogleSearch    = "google_search"
	toolNameWebSearch       = "web_search"
	toolNameGoogleSearch    = "google_search"
	toolNameWebSearch2025   = "web_search_20250305"

	// Defaults used when synthesizing the emulated response.
	webSearchDefaultMaxResults = 5
	defaultWebSearchModel      = "claude-sonnet-4-6"
	webSearchMsgIDPrefix       = "msg_ws_"
	webSearchToolUseIDPrefix   = "srvtoolu_ws_"
	// Rough characters-per-token ratio used to estimate output_tokens.
	tokenEstimateDivisor = 4

	// featureKeyWebSearchEmulation is the key used in Account.Extra and Channel.FeaturesConfig.
	featureKeyWebSearchEmulation = "web_search_emulation"
)
+
// webSearchManagerPtr stores *websearch.Manager atomically for concurrent safety.
var webSearchManagerPtr atomic.Pointer[websearch.Manager]

// SetWebSearchManager wires the websearch.Manager into the gateway (goroutine-safe).
func SetWebSearchManager(m *websearch.Manager) {
	webSearchManagerPtr.Store(m)
}

// getWebSearchManager returns the currently wired manager, or nil when web
// search emulation has not been configured.
func getWebSearchManager() *websearch.Manager {
	return webSearchManagerPtr.Load()
}
+
+// shouldEmulateWebSearch checks whether a request should be intercepted.
+//
+// Judgment chain: manager exists → only web_search tool → global enabled → account enabled.
+// Note: channel-level control is enforced via the account's extra field; the channel toggle
+// in the admin UI sets the account's flag for all accounts in that channel's groups.
+func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, body []byte) bool {
+ if getWebSearchManager() == nil {
+ return false
+ }
+ if !isOnlyWebSearchToolInBody(body) {
+ return false
+ }
+ if !s.settingService.IsWebSearchEmulationEnabled(ctx) {
+ return false
+ }
+ if !account.IsWebSearchEmulationEnabled() {
+ return false
+ }
+ return true
+}
+
+// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
+func isOnlyWebSearchToolInBody(body []byte) bool {
+ tools := gjson.GetBytes(body, "tools")
+ if !tools.IsArray() {
+ return false
+ }
+ arr := tools.Array()
+ if len(arr) != 1 {
+ return false
+ }
+ return isWebSearchToolJSON(arr[0])
+}
+
+func isWebSearchToolJSON(tool gjson.Result) bool {
+ toolType := tool.Get("type").String()
+ if strings.HasPrefix(toolType, toolTypeWebSearchPrefix) || toolType == toolTypeGoogleSearch {
+ return true
+ }
+ switch tool.Get("name").String() {
+ case toolNameWebSearch, toolNameGoogleSearch, toolNameWebSearch2025:
+ return true
+ }
+ return false
+}
+
+// extractSearchQueryFromBody extracts the last user message text as the search query.
+func extractSearchQueryFromBody(body []byte) string {
+ messages := gjson.GetBytes(body, "messages")
+ if !messages.IsArray() {
+ return ""
+ }
+ arr := messages.Array()
+ if len(arr) == 0 {
+ return ""
+ }
+ lastMsg := arr[len(arr)-1]
+ if lastMsg.Get("role").String() != "user" {
+ return ""
+ }
+ return extractWebSearchTextFromContent(lastMsg.Get("content"))
+}
+
+func extractWebSearchTextFromContent(content gjson.Result) string {
+ if content.Type == gjson.String {
+ return content.String()
+ }
+ if content.IsArray() {
+ for _, block := range content.Array() {
+ if block.Get("type").String() == "text" {
+ if text := block.Get("text").String(); text != "" {
+ return text
+ }
+ }
+ }
+ }
+ return ""
+}
+
+// handleWebSearchEmulation intercepts a web-search-only request,
+// calls a third-party search API, and constructs an Anthropic-format response.
+func (s *GatewayService) handleWebSearchEmulation(
+ ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest,
+) (*ForwardResult, error) {
+ startTime := time.Now()
+
+ // Release the serial queue lock immediately — we don't need upstream.
+ if parsed.OnUpstreamAccepted != nil {
+ parsed.OnUpstreamAccepted()
+ }
+
+ query := extractSearchQueryFromBody(parsed.Body)
+ if query == "" {
+ return nil, fmt.Errorf("web search emulation: no query found in messages")
+ }
+
+ slog.Info("web search emulation: executing search",
+ "account_id", account.ID, "account_name", account.Name, "query", query)
+
+ resp, providerName, err := doWebSearch(ctx, account, query)
+ if err != nil {
+ return nil, err
+ }
+
+ slog.Info("web search emulation: search completed",
+ "provider", providerName, "results_count", len(resp.Results))
+
+ model := parsed.Model
+ if model == "" {
+ model = defaultWebSearchModel
+ }
+
+ if parsed.Stream {
+ return writeWebSearchStreamResponse(c, query, resp, model, startTime)
+ }
+ return writeWebSearchNonStreamResponse(c, query, resp, model, startTime)
+}
+
+func doWebSearch(ctx context.Context, account *Account, query string) (*websearch.SearchResponse, string, error) {
+ proxyURL := resolveAccountProxyURL(account)
+ mgr := getWebSearchManager()
+ if mgr == nil {
+ return nil, "", fmt.Errorf("web search emulation: manager not initialized")
+ }
+ resp, providerName, err := mgr.SearchWithBestProvider(ctx, websearch.SearchRequest{
+ Query: query, MaxResults: webSearchDefaultMaxResults, ProxyURL: proxyURL,
+ })
+ if err != nil {
+ slog.Error("web search emulation: search failed", "error", err)
+ return nil, "", fmt.Errorf("web search emulation: %w", err)
+ }
+ return resp, providerName, nil
+}
+
+func resolveAccountProxyURL(account *Account) string {
+ if account.ProxyID != nil && account.Proxy != nil {
+ return account.Proxy.URL()
+ }
+ return ""
+}
+
+// --- SSE streaming response ---
+
+func writeWebSearchStreamResponse(
+ c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
+) (*ForwardResult, error) {
+ msgID := webSearchMsgIDPrefix + uuid.New().String()
+ toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
+
+ setSSEHeaders(c)
+ if err := writeSSEMessageStart(c.Writer, msgID, model); err != nil {
+ return nil, fmt.Errorf("web search emulation: SSE write: %w", err)
+ }
+ writeSSEServerToolUse(c.Writer, toolUseID, query, 0)
+ writeSSEToolResult(c.Writer, toolUseID, resp.Results, 1)
+ textSummary := buildTextSummary(query, resp.Results)
+ writeSSETextBlock(c.Writer, textSummary, 2)
+ writeSSEMessageEnd(c.Writer, len(textSummary)/tokenEstimateDivisor)
+ c.Writer.Flush()
+
+ return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
+}
+
+func setSSEHeaders(c *gin.Context) {
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.WriteHeader(http.StatusOK)
+}
+
+func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
+ evt := map[string]any{
+ "type": "message_start",
+ "message": map[string]any{
+ "id": msgID, "type": "message", "role": "assistant", "model": model,
+ "content": []any{}, "stop_reason": nil, "stop_sequence": nil,
+ "usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
+ },
+ }
+ return flushSSEJSON(w, "message_start", evt)
+}
+
+// writeSSEServerToolUse emits a server_tool_use content block (start + stop)
+// carrying the search query as the tool input.
+func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index int) {
+    block := map[string]any{
+        "type":  "server_tool_use",
+        "id":    toolUseID,
+        "name":  toolNameWebSearch,
+        "input": map[string]string{"query": query},
+    }
+    _ = flushSSEJSON(w, "content_block_start", map[string]any{
+        "type": "content_block_start", "index": index, "content_block": block,
+    })
+    _ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
+}
+
+// writeSSEToolResult emits the web_search_tool_result content block (start +
+// stop), linking back to the earlier server_tool_use via tool_use_id.
+func writeSSEToolResult(w http.ResponseWriter, toolUseID string, results []websearch.SearchResult, index int) {
+    block := map[string]any{
+        "type":        "web_search_tool_result",
+        "tool_use_id": toolUseID,
+        "content":     buildSearchResultBlocks(results),
+    }
+    _ = flushSSEJSON(w, "content_block_start", map[string]any{
+        "type": "content_block_start", "index": index, "content_block": block,
+    })
+    _ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
+}
+
+// writeSSETextBlock streams a complete text content block at the given index:
+// a start event with empty text, one delta carrying the whole text, then stop.
+func writeSSETextBlock(w http.ResponseWriter, text string, index int) {
+    events := []struct {
+        name    string
+        payload map[string]any
+    }{
+        {"content_block_start", map[string]any{
+            "type": "content_block_start", "index": index,
+            "content_block": map[string]any{"type": "text", "text": ""},
+        }},
+        {"content_block_delta", map[string]any{
+            "type": "content_block_delta", "index": index,
+            "delta": map[string]string{"type": "text_delta", "text": text},
+        }},
+        {"content_block_stop", map[string]any{"type": "content_block_stop", "index": index}},
+    }
+    for _, ev := range events {
+        _ = flushSSEJSON(w, ev.name, ev.payload)
+    }
+}
+
+// writeSSEMessageEnd closes the stream: message_delta carries the final stop
+// reason plus the estimated output token count, then message_stop terminates.
+func writeSSEMessageEnd(w http.ResponseWriter, outputTokens int) {
+    delta := map[string]any{
+        "type":  "message_delta",
+        "delta": map[string]any{"stop_reason": "end_turn", "stop_sequence": nil},
+        "usage": map[string]int{"output_tokens": outputTokens},
+    }
+    _ = flushSSEJSON(w, "message_delta", delta)
+    _ = flushSSEJSON(w, "message_stop", map[string]string{"type": "message_stop"})
+}
+
+// flushSSEJSON marshals data to JSON, writes it as a single SSE event
+// ("event: <name>\ndata: <json>\n\n") and flushes the connection.
+// It returns an error on marshal failure or when the underlying write fails.
+func flushSSEJSON(w http.ResponseWriter, event string, data any) error {
+    b, err := json.Marshal(data)
+    if err != nil {
+        slog.Error("web search emulation: failed to marshal SSE event",
+            "event", event, "error", err)
+        return err
+    }
+    // Propagate write failures (e.g. client disconnect) instead of silently
+    // reporting success; writeSSEMessageStart relies on this error to abort.
+    if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, b); err != nil {
+        return err
+    }
+    if f, ok := w.(http.Flusher); ok {
+        f.Flush()
+    }
+    return nil
+}
+
+// --- Non-streaming JSON response ---
+
+// writeWebSearchNonStreamResponse renders the emulated search as one complete
+// Anthropic message: a server_tool_use block, the web_search_tool_result
+// block and a human-readable text summary, with a length-based token estimate.
+func writeWebSearchNonStreamResponse(
+    c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
+) (*ForwardResult, error) {
+    msgID := webSearchMsgIDPrefix + uuid.New().String()
+    toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
+    textSummary := buildTextSummary(query, resp.Results)
+
+    toolUse := map[string]any{
+        "type":  "server_tool_use",
+        "id":    toolUseID,
+        "name":  toolNameWebSearch,
+        "input": map[string]string{"query": query},
+    }
+    toolResult := map[string]any{
+        "type":        "web_search_tool_result",
+        "tool_use_id": toolUseID,
+        "content":     buildSearchResultBlocks(resp.Results),
+    }
+    textBlock := map[string]any{"type": "text", "text": textSummary}
+
+    msg := map[string]any{
+        "id":            msgID,
+        "type":          "message",
+        "role":          "assistant",
+        "model":         model,
+        "content":       []any{toolUse, toolResult, textBlock},
+        "stop_reason":   "end_turn",
+        "stop_sequence": nil,
+        "usage": map[string]int{
+            "input_tokens":  0,
+            "output_tokens": len(textSummary) / tokenEstimateDivisor,
+        },
+    }
+
+    body, err := json.Marshal(msg)
+    if err != nil {
+        return nil, fmt.Errorf("web search emulation: marshal response: %w", err)
+    }
+    c.Data(http.StatusOK, "application/json", body)
+
+    return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
+}
+
+// --- Helpers ---
+
+// buildSearchResultBlocks converts provider results into Anthropic
+// web_search_result content blocks. Optional fields (page_content, page_age)
+// are emitted only when the provider supplied a value.
+func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
+    blocks := make([]map[string]string, 0, len(results))
+    for _, res := range results {
+        entry := map[string]string{
+            "type":  "web_search_result",
+            "url":   res.URL,
+            "title": res.Title,
+        }
+        if snippet := res.Snippet; snippet != "" {
+            entry["page_content"] = snippet
+        }
+        if age := res.PageAge; age != "" {
+            entry["page_age"] = age
+        }
+        blocks = append(blocks, entry)
+    }
+    return blocks
+}
+
+// buildTextSummary renders a numbered, markdown-flavoured list of the search
+// results, or a fallback sentence when no results were found.
+func buildTextSummary(query string, results []websearch.SearchResult) string {
+    if len(results) == 0 {
+        return "No search results found for: " + query
+    }
+    var b strings.Builder
+    fmt.Fprintf(&b, "Here are the search results for \"%s\":\n\n", query)
+    for idx, res := range results {
+        fmt.Fprintf(&b, "%d. **%s**\n %s\n %s\n\n", idx+1, res.Title, res.URL, res.Snippet)
+    }
+    return b.String()
+}
diff --git a/backend/internal/service/gateway_websearch_emulation_test.go b/backend/internal/service/gateway_websearch_emulation_test.go
new file mode 100644
index 00000000..b606c748
--- /dev/null
+++ b/backend/internal/service/gateway_websearch_emulation_test.go
@@ -0,0 +1,142 @@
+package service
+
+import (
+    "testing"
+
+    "github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
+    "github.com/stretchr/testify/require"
+)
+
+// --- isOnlyWebSearchToolInBody ---
+// Emulation must trigger only when the request's tools array holds exactly
+// one entry and that entry is a recognized web-search tool (by type or name).
+
+func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
+    require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search"}]}`)))
+}
+
+func TestIsOnlyWebSearchToolInBody_WebSearch2025Type(t *testing.T) {
+    require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search_20250305"}]}`)))
+}
+
+func TestIsOnlyWebSearchToolInBody_GoogleSearchType(t *testing.T) {
+    require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"google_search"}]}`)))
+}
+
+func TestIsOnlyWebSearchToolInBody_NameWebSearch(t *testing.T) {
+    require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search"}]}`)))
+}
+
+func TestIsOnlyWebSearchToolInBody_NameWebSearch2025(t *testing.T) {
+    require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search_20250305"}]}`)))
+}
+
+func TestIsOnlyWebSearchToolInBody_NameGoogleSearch(t *testing.T) {
+    require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"google_search"}]}`)))
+}
+
+// Mixed tool lists must NOT be emulated: the request is forwarded upstream.
+func TestIsOnlyWebSearchToolInBody_MultipleTools(t *testing.T) {
+    require.False(t, isOnlyWebSearchToolInBody(
+        []byte(`{"tools":[{"type":"web_search"},{"type":"text_editor"}]}`)))
+}
+
+func TestIsOnlyWebSearchToolInBody_NoTools(t *testing.T) {
+    require.False(t, isOnlyWebSearchToolInBody([]byte(`{"model":"claude-3"}`)))
+}
+
+func TestIsOnlyWebSearchToolInBody_EmptyToolsArray(t *testing.T) {
+    require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[]}`)))
+}
+
+func TestIsOnlyWebSearchToolInBody_NonWebSearchTool(t *testing.T) {
+    require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"text_editor"}]}`)))
+}
+
+// Malformed shape (tools is a string, not an array) must be rejected safely.
+func TestIsOnlyWebSearchToolInBody_ToolsNotArray(t *testing.T) {
+    require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":"web_search"}`)))
+}
+
+// --- extractSearchQueryFromBody ---
+// The query is taken from the LAST message, and only when it is a user turn;
+// both plain-string and content-block-array message bodies are supported.
+
+func TestExtractSearchQueryFromBody_StringContent(t *testing.T) {
+    body := `{"messages":[{"role":"user","content":"what is golang"}]}`
+    require.Equal(t, "what is golang", extractSearchQueryFromBody([]byte(body)))
+}
+
+func TestExtractSearchQueryFromBody_ArrayContent(t *testing.T) {
+    body := `{"messages":[{"role":"user","content":[{"type":"text","text":"search this"}]}]}`
+    require.Equal(t, "search this", extractSearchQueryFromBody([]byte(body)))
+}
+
+func TestExtractSearchQueryFromBody_MultipleMessages(t *testing.T) {
+    body := `{"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}]}`
+    require.Equal(t, "second", extractSearchQueryFromBody([]byte(body)))
+}
+
+func TestExtractSearchQueryFromBody_LastMessageNotUser(t *testing.T) {
+    body := `{"messages":[{"role":"user","content":"q"},{"role":"assistant","content":"a"}]}`
+    require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
+}
+
+func TestExtractSearchQueryFromBody_EmptyMessages(t *testing.T) {
+    require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"messages":[]}`)))
+}
+
+func TestExtractSearchQueryFromBody_NoMessages(t *testing.T) {
+    require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"model":"claude-3"}`)))
+}
+
+func TestExtractSearchQueryFromBody_ArrayContentSkipsEmptyText(t *testing.T) {
+    body := `{"messages":[{"role":"user","content":[{"type":"image"},{"type":"text","text":""},{"type":"text","text":"real query"}]}]}`
+    require.Equal(t, "real query", extractSearchQueryFromBody([]byte(body)))
+}
+
+func TestExtractSearchQueryFromBody_ArrayContentNoTextBlock(t *testing.T) {
+    body := `{"messages":[{"role":"user","content":[{"type":"image","source":{}}]}]}`
+    require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
+}
+
+// --- buildSearchResultBlocks ---
+// Optional fields (page_content, page_age) must be omitted when empty.
+
+func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
+    results := []websearch.SearchResult{
+        {URL: "https://a.com", Title: "A", Snippet: "snippet a", PageAge: "2 days"},
+        {URL: "https://b.com", Title: "B", Snippet: "snippet b"},
+    }
+    blocks := buildSearchResultBlocks(results)
+    require.Len(t, blocks, 2)
+    require.Equal(t, "web_search_result", blocks[0]["type"])
+    require.Equal(t, "https://a.com", blocks[0]["url"])
+    require.Equal(t, "snippet a", blocks[0]["page_content"])
+    require.Equal(t, "2 days", blocks[0]["page_age"])
+    // Second result has no PageAge
+    require.Equal(t, "https://b.com", blocks[1]["url"])
+    _, hasPageAge := blocks[1]["page_age"]
+    require.False(t, hasPageAge)
+}
+
+func TestBuildSearchResultBlocks_Empty(t *testing.T) {
+    blocks := buildSearchResultBlocks(nil)
+    require.Empty(t, blocks)
+}
+
+func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
+    blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
+    _, hasContent := blocks[0]["page_content"]
+    require.False(t, hasContent)
+}
+
+// --- buildTextSummary ---
+
+func TestBuildTextSummary_WithResults(t *testing.T) {
+    results := []websearch.SearchResult{
+        {URL: "https://a.com", Title: "A", Snippet: "desc a"},
+    }
+    summary := buildTextSummary("test query", results)
+    require.Contains(t, summary, "test query")
+    require.Contains(t, summary, "1. **A**")
+    require.Contains(t, summary, "https://a.com")
+}
+
+func TestBuildTextSummary_NoResults(t *testing.T) {
+    summary := buildTextSummary("test", nil)
+    require.Contains(t, summary, "No search results found for: test")
+}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 48f25da0..3cfe5e56 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/imroc/req/v3"
+ "github.com/redis/go-redis/v9"
"golang.org/x/sync/singleflight"
)
@@ -106,6 +107,7 @@ type SettingService struct {
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version
+ webSearchRedis *redis.Client // optional: Redis client for web search quota tracking
}
// NewSettingService 创建系统设置服务实例
@@ -1217,6 +1219,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
+ // Web search emulation: quick enabled check from the JSON config
+ if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
+ var wsCfg WebSearchEmulationConfig
+ if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil {
+ result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
+ }
+ }
+
return result
}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index de92b796..f5535bca 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -106,6 +106,9 @@ type SystemSettings struct {
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
+
+ // Web Search Emulation (read-only quick check; full config via dedicated API)
+ WebSearchEmulationEnabled bool
}
type DefaultSubscriptionSetting struct {
diff --git a/backend/internal/service/websearch_config.go b/backend/internal/service/websearch_config.go
new file mode 100644
index 00000000..15ec1f9d
--- /dev/null
+++ b/backend/internal/service/websearch_config.go
@@ -0,0 +1,253 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "sync/atomic"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
+ "github.com/redis/go-redis/v9"
+ "golang.org/x/sync/singleflight"
+)
+
+// WebSearchEmulationConfig holds the global web search emulation configuration.
+// It is persisted as a single JSON value under SettingKeyWebSearchEmulationConfig.
+type WebSearchEmulationConfig struct {
+    Enabled   bool                      `json:"enabled"`
+    Providers []WebSearchProviderConfig `json:"providers"`
+}
+
+// WebSearchProviderConfig describes a single search provider (Brave or Tavily).
+// APIKey is stored in the DB but blanked by SanitizeWebSearchConfig before
+// being returned over the admin API; APIKeyConfigured and QuotaUsed are
+// read-only fields populated for API responses.
+type WebSearchProviderConfig struct {
+    Type                 string `json:"type"`                   // websearch.ProviderTypeBrave | Tavily
+    APIKey               string `json:"api_key,omitempty"`      // secret — omitted in API responses
+    APIKeyConfigured     bool   `json:"api_key_configured"`     // read-only mask
+    Priority             int    `json:"priority"`               // lower = higher priority
+    QuotaLimit           int64  `json:"quota_limit"`            // 0 = unlimited
+    QuotaRefreshInterval string `json:"quota_refresh_interval"` // websearch.QuotaRefresh*
+    QuotaUsed            int64  `json:"quota_used,omitempty"`   // read-only: current period usage
+    ProxyID              *int64 `json:"proxy_id"`               // optional proxy association
+    ExpiresAt            *int64 `json:"expires_at,omitempty"`   // optional expiration timestamp
+}
+
+// --- Validation ---
+
+// maxWebSearchProviders caps the provider list accepted by the admin API.
+const maxWebSearchProviders = 10
+
+// validProviderTypes enumerates the supported search backends.
+var validProviderTypes = map[string]bool{
+    websearch.ProviderTypeBrave:  true,
+    websearch.ProviderTypeTavily: true,
+}
+
+// validQuotaIntervals enumerates the accepted quota refresh windows; the
+// empty string is accepted here and treated as monthly downstream.
+var validQuotaIntervals = map[string]bool{
+    websearch.QuotaRefreshDaily:   true,
+    websearch.QuotaRefreshWeekly:  true,
+    websearch.QuotaRefreshMonthly: true,
+    "": true, // defaults to monthly
+}
+
+// validateWebSearchConfig checks provider count, provider types, quota
+// settings and duplicate entries. A nil config is considered valid.
+func validateWebSearchConfig(cfg *WebSearchEmulationConfig) error {
+    if cfg == nil {
+        return nil
+    }
+    if len(cfg.Providers) > maxWebSearchProviders {
+        return fmt.Errorf("too many providers (max %d)", maxWebSearchProviders)
+    }
+    seen := make(map[string]bool, len(cfg.Providers))
+    for i, prov := range cfg.Providers {
+        switch {
+        case !validProviderTypes[prov.Type]:
+            return fmt.Errorf("provider[%d]: invalid type %q", i, prov.Type)
+        case !validQuotaIntervals[prov.QuotaRefreshInterval]:
+            return fmt.Errorf("provider[%d]: invalid quota_refresh_interval %q", i, prov.QuotaRefreshInterval)
+        case prov.QuotaLimit < 0:
+            return fmt.Errorf("provider[%d]: quota_limit must be >= 0", i)
+        case seen[prov.Type]:
+            return fmt.Errorf("provider[%d]: duplicate type %q", i, prov.Type)
+        }
+        seen[prov.Type] = true
+    }
+    return nil
+}
+
+// --- In-process cache (same pattern as gateway forwarding settings) ---
+
+// sfKeyWebSearchConfig is the singleflight key that collapses concurrent loads.
+const sfKeyWebSearchConfig = "web_search_emulation_config"
+
+// cachedWebSearchEmulationConfig is a single immutable cache entry.
+type cachedWebSearchEmulationConfig struct {
+    config    *WebSearchEmulationConfig
+    expiresAt int64 // unix nano
+}
+
+// webSearchEmulationCache always holds a *cachedWebSearchEmulationConfig;
+// atomic.Value gives lock-free reads on the request hot path.
+var webSearchEmulationCache atomic.Value // *cachedWebSearchEmulationConfig
+var webSearchEmulationSF singleflight.Group
+
+const (
+    webSearchEmulationCacheTTL  = 60 * time.Second // TTL after a successful load
+    webSearchEmulationErrorTTL  = 5 * time.Second  // shorter TTL after a failed load, so retries happen soon
+    webSearchEmulationDBTimeout = 5 * time.Second  // per-load DB budget
+)
+
+// GetWebSearchEmulationConfig returns the configuration, served from a short
+// in-process cache; concurrent cache misses are collapsed via singleflight.
+// On load failure an empty (disabled) config is returned alongside the error.
+func (s *SettingService) GetWebSearchEmulationConfig(ctx context.Context) (*WebSearchEmulationConfig, error) {
+    if v := webSearchEmulationCache.Load(); v != nil {
+        entry := v.(*cachedWebSearchEmulationConfig)
+        if time.Now().UnixNano() < entry.expiresAt {
+            return entry.config, nil
+        }
+    }
+    res, err, _ := webSearchEmulationSF.Do(sfKeyWebSearchConfig, func() (any, error) {
+        return s.loadWebSearchConfigFromDB()
+    })
+    if err != nil {
+        return &WebSearchEmulationConfig{}, err
+    }
+    return res.(*WebSearchEmulationConfig), nil
+}
+
+// loadWebSearchConfigFromDB fetches and parses the persisted JSON config,
+// refreshing the in-process cache. Failures are cached briefly (error TTL)
+// so a broken database is not hammered on every request.
+func (s *SettingService) loadWebSearchConfigFromDB() (*WebSearchEmulationConfig, error) {
+    dbCtx, cancel := context.WithTimeout(context.Background(), webSearchEmulationDBTimeout)
+    defer cancel()
+
+    store := func(cfg *WebSearchEmulationConfig, ttl time.Duration) {
+        webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
+            config:    cfg,
+            expiresAt: time.Now().Add(ttl).UnixNano(),
+        })
+    }
+
+    raw, err := s.settingRepo.GetValue(dbCtx, SettingKeyWebSearchEmulationConfig)
+    if err != nil {
+        empty := &WebSearchEmulationConfig{}
+        store(empty, webSearchEmulationErrorTTL)
+        return empty, err
+    }
+
+    cfg := parseWebSearchConfigJSON(raw)
+    store(cfg, webSearchEmulationCacheTTL)
+    return cfg, nil
+}
+
+func parseWebSearchConfigJSON(raw string) *WebSearchEmulationConfig {
+ cfg := &WebSearchEmulationConfig{}
+ if raw == "" {
+ return cfg
+ }
+ if err := json.Unmarshal([]byte(raw), cfg); err != nil {
+ slog.Warn("websearch: failed to parse config JSON", "error", err)
+ return &WebSearchEmulationConfig{}
+ }
+ return cfg
+}
+
+// SaveWebSearchEmulationConfig validates and persists the configuration.
+// Empty API keys in the input are preserved from the existing config (the
+// admin UI never echoes secrets back, so an empty key means "keep current").
+// On success the in-process cache is refreshed and the global websearch
+// manager is rebuilt with the new settings.
+func (s *SettingService) SaveWebSearchEmulationConfig(ctx context.Context, cfg *WebSearchEmulationConfig) error {
+    if err := validateWebSearchConfig(cfg); err != nil {
+        return infraerrors.BadRequest("INVALID_WEB_SEARCH_CONFIG", err.Error())
+    }
+    s.mergeExistingAPIKeys(ctx, cfg)
+
+    data, err := json.Marshal(cfg)
+    if err != nil {
+        return fmt.Errorf("websearch: marshal config: %w", err)
+    }
+    if err := s.settingRepo.Set(ctx, SettingKeyWebSearchEmulationConfig, string(data)); err != nil {
+        return fmt.Errorf("websearch: save config: %w", err)
+    }
+    // Invalidate: forget singleflight first, then store new value.
+    // Order matters: forgetting first prevents an in-flight stale load from
+    // being handed out under the old singleflight key after we overwrite.
+    webSearchEmulationSF.Forget(sfKeyWebSearchConfig)
+    webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
+        config:    cfg,
+        expiresAt: time.Now().Add(webSearchEmulationCacheTTL).UnixNano(),
+    })
+
+    // Hot-reload: rebuild the global Manager with new config
+    s.RebuildWebSearchManager(ctx)
+    return nil
+}
+
+// mergeExistingAPIKeys preserves API keys from the stored config when an
+// incoming provider entry has an empty api_key (the admin UI never sends
+// secrets back, so an empty key means "keep the current one"). Best-effort:
+// if the stored config cannot be read, the incoming config is left as-is.
+func (s *SettingService) mergeExistingAPIKeys(ctx context.Context, cfg *WebSearchEmulationConfig) {
+    // Skip the repository read entirely when there is nothing to merge into.
+    if cfg == nil || len(cfg.Providers) == 0 {
+        return
+    }
+    existing, err := s.getWebSearchEmulationConfigRaw(ctx)
+    if err != nil || existing == nil {
+        return
+    }
+    existingByType := make(map[string]string, len(existing.Providers))
+    for _, p := range existing.Providers {
+        if p.APIKey != "" {
+            existingByType[p.Type] = p.APIKey
+        }
+    }
+    for i := range cfg.Providers {
+        if cfg.Providers[i].APIKey != "" {
+            continue // caller supplied a new key; keep it
+        }
+        if key, ok := existingByType[cfg.Providers[i].Type]; ok {
+            cfg.Providers[i].APIKey = key
+        }
+    }
+}
+
+// getWebSearchEmulationConfigRaw loads the persisted config directly from the
+// setting repository, bypassing the in-process cache and without masking
+// secrets. Used internally when the unmodified stored API keys are needed.
+func (s *SettingService) getWebSearchEmulationConfigRaw(ctx context.Context) (*WebSearchEmulationConfig, error) {
+    raw, err := s.settingRepo.GetValue(ctx, SettingKeyWebSearchEmulationConfig)
+    if err != nil {
+        return nil, err
+    }
+    return parseWebSearchConfigJSON(raw), nil
+}
+
+// IsWebSearchEmulationEnabled reports whether the global switch is on and at
+// least one provider is configured; load errors count as disabled.
+func (s *SettingService) IsWebSearchEmulationEnabled(ctx context.Context) bool {
+    cfg, err := s.GetWebSearchEmulationConfig(ctx)
+    return err == nil && cfg.Enabled && len(cfg.Providers) > 0
+}
+
+// SetWebSearchRedisClient injects the Redis client used for quota tracking.
+// Call after construction, before first use. Triggers the initial Manager
+// build so quota accounting is wired up from startup.
+func (s *SettingService) SetWebSearchRedisClient(ctx context.Context, redisClient *redis.Client) {
+    s.webSearchRedis = redisClient
+    s.RebuildWebSearchManager(ctx)
+}
+
+// RebuildWebSearchManager reads the current config and (re)creates the global
+// websearch.Manager. A disabled or empty configuration (or a load error)
+// clears the manager instead. Called on startup and after a config save.
+func (s *SettingService) RebuildWebSearchManager(ctx context.Context) {
+    cfg, err := s.GetWebSearchEmulationConfig(ctx)
+    if err != nil || !cfg.Enabled || len(cfg.Providers) == 0 {
+        SetWebSearchManager(nil)
+        return
+    }
+    configs := make([]websearch.ProviderConfig, len(cfg.Providers))
+    for i, p := range cfg.Providers {
+        configs[i] = websearch.ProviderConfig{
+            Type:                 p.Type,
+            APIKey:               p.APIKey,
+            Priority:             p.Priority,
+            QuotaLimit:           p.QuotaLimit,
+            QuotaRefreshInterval: p.QuotaRefreshInterval,
+            ExpiresAt:            p.ExpiresAt,
+        }
+    }
+    SetWebSearchManager(websearch.NewManager(configs, s.webSearchRedis))
+    slog.Info("websearch: manager rebuilt", "provider_count", len(configs))
+}
+
+// SanitizeWebSearchConfig returns a copy of cfg with every api_key blanked
+// and api_key_configured set to reflect whether a key exists. The input is
+// never mutated. A nil config sanitizes to nil.
+func SanitizeWebSearchConfig(cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig {
+    if cfg == nil {
+        return nil
+    }
+    sanitized := WebSearchEmulationConfig{
+        Enabled:   cfg.Enabled,
+        Providers: make([]WebSearchProviderConfig, len(cfg.Providers)),
+    }
+    for i := range cfg.Providers {
+        p := cfg.Providers[i] // value copy; safe to modify
+        p.APIKeyConfigured = p.APIKey != ""
+        p.APIKey = "" // never return the secret
+        sanitized.Providers[i] = p
+    }
+    return &sanitized
+}
diff --git a/backend/internal/service/websearch_config_test.go b/backend/internal/service/websearch_config_test.go
new file mode 100644
index 00000000..1a19dd9d
--- /dev/null
+++ b/backend/internal/service/websearch_config_test.go
@@ -0,0 +1,148 @@
+package service
+
+import (
+    "testing"
+
+    "github.com/stretchr/testify/require"
+)
+
+// --- validateWebSearchConfig ---
+// Covers the provider-count cap, type/interval whitelists, quota bounds and
+// duplicate-type detection; nil and zero-value edge cases included.
+
+func TestValidateWebSearchConfig_Nil(t *testing.T) {
+    require.NoError(t, validateWebSearchConfig(nil))
+}
+
+func TestValidateWebSearchConfig_Valid(t *testing.T) {
+    cfg := &WebSearchEmulationConfig{
+        Enabled: true,
+        Providers: []WebSearchProviderConfig{
+            {Type: "brave", Priority: 1, QuotaLimit: 1000, QuotaRefreshInterval: "monthly"},
+            {Type: "tavily", Priority: 2, QuotaLimit: 500, QuotaRefreshInterval: "daily"},
+        },
+    }
+    require.NoError(t, validateWebSearchConfig(cfg))
+}
+
+func TestValidateWebSearchConfig_TooManyProviders(t *testing.T) {
+    cfg := &WebSearchEmulationConfig{Providers: make([]WebSearchProviderConfig, 11)}
+    for i := range cfg.Providers {
+        cfg.Providers[i] = WebSearchProviderConfig{Type: "brave"}
+    }
+    err := validateWebSearchConfig(cfg)
+    require.ErrorContains(t, err, "too many providers")
+}
+
+func TestValidateWebSearchConfig_InvalidType(t *testing.T) {
+    cfg := &WebSearchEmulationConfig{
+        Providers: []WebSearchProviderConfig{{Type: "bing"}},
+    }
+    require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid type")
+}
+
+func TestValidateWebSearchConfig_InvalidQuotaInterval(t *testing.T) {
+    cfg := &WebSearchEmulationConfig{
+        Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: "hourly"}},
+    }
+    require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid quota_refresh_interval")
+}
+
+func TestValidateWebSearchConfig_NegativeQuotaLimit(t *testing.T) {
+    cfg := &WebSearchEmulationConfig{
+        Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: -1}},
+    }
+    require.ErrorContains(t, validateWebSearchConfig(cfg), "quota_limit must be >= 0")
+}
+
+func TestValidateWebSearchConfig_DuplicateType(t *testing.T) {
+    cfg := &WebSearchEmulationConfig{
+        Providers: []WebSearchProviderConfig{
+            {Type: "brave", Priority: 1},
+            {Type: "brave", Priority: 2},
+        },
+    }
+    require.ErrorContains(t, validateWebSearchConfig(cfg), "duplicate type")
+}
+
+// Empty interval is accepted (treated as monthly downstream).
+func TestValidateWebSearchConfig_EmptyQuotaInterval(t *testing.T) {
+    cfg := &WebSearchEmulationConfig{
+        Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: ""}},
+    }
+    require.NoError(t, validateWebSearchConfig(cfg))
+}
+
+// QuotaLimit 0 means unlimited and must validate.
+func TestValidateWebSearchConfig_ZeroQuotaLimit(t *testing.T) {
+    cfg := &WebSearchEmulationConfig{
+        Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: 0}},
+    }
+    require.NoError(t, validateWebSearchConfig(cfg))
+}
+
+// --- parseWebSearchConfigJSON ---
+// Empty and malformed input must both degrade to a disabled config.
+
+func TestParseWebSearchConfigJSON_ValidJSON(t *testing.T) {
+    raw := `{"enabled":true,"providers":[{"type":"brave","api_key":"sk-xxx"}]}`
+    cfg := parseWebSearchConfigJSON(raw)
+    require.True(t, cfg.Enabled)
+    require.Len(t, cfg.Providers, 1)
+    require.Equal(t, "brave", cfg.Providers[0].Type)
+}
+
+func TestParseWebSearchConfigJSON_EmptyString(t *testing.T) {
+    cfg := parseWebSearchConfigJSON("")
+    require.False(t, cfg.Enabled)
+    require.Empty(t, cfg.Providers)
+}
+
+func TestParseWebSearchConfigJSON_InvalidJSON(t *testing.T) {
+    cfg := parseWebSearchConfigJSON("not{json")
+    require.False(t, cfg.Enabled)
+    require.Empty(t, cfg.Providers)
+}
+
+// --- SanitizeWebSearchConfig ---
+// The secret must never leave the server; the mask flag reports presence.
+
+func TestSanitizeWebSearchConfig_MaskAPIKey(t *testing.T) {
+    cfg := &WebSearchEmulationConfig{
+        Enabled: true,
+        Providers: []WebSearchProviderConfig{
+            {Type: "brave", APIKey: "sk-secret-xxx"},
+        },
+    }
+    out := SanitizeWebSearchConfig(cfg)
+    require.Equal(t, "", out.Providers[0].APIKey)
+    require.True(t, out.Providers[0].APIKeyConfigured)
+}
+
+func TestSanitizeWebSearchConfig_NoAPIKey(t *testing.T) {
+    cfg := &WebSearchEmulationConfig{
+        Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: ""}},
+    }
+    out := SanitizeWebSearchConfig(cfg)
+    require.Equal(t, "", out.Providers[0].APIKey)
+    require.False(t, out.Providers[0].APIKeyConfigured)
+}
+
+func TestSanitizeWebSearchConfig_Nil(t *testing.T) {
+    require.Nil(t, SanitizeWebSearchConfig(nil))
+}
+
+func TestSanitizeWebSearchConfig_PreservesOtherFields(t *testing.T) {
+    cfg := &WebSearchEmulationConfig{
+        Enabled: true,
+        Providers: []WebSearchProviderConfig{
+            {Type: "brave", APIKey: "secret", Priority: 10, QuotaLimit: 1000},
+        },
+    }
+    out := SanitizeWebSearchConfig(cfg)
+    require.True(t, out.Enabled)
+    require.Equal(t, 10, out.Providers[0].Priority)
+    require.Equal(t, int64(1000), out.Providers[0].QuotaLimit)
+}
+
+// Sanitizing must copy, not mutate, the caller's config.
+func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) {
+    cfg := &WebSearchEmulationConfig{
+        Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "secret"}},
+    }
+    _ = SanitizeWebSearchConfig(cfg)
+    require.Equal(t, "secret", cfg.Providers[0].APIKey)
+}
diff --git a/backend/migrations/101_add_channel_features_config.sql b/backend/migrations/101_add_channel_features_config.sql
new file mode 100644
index 00000000..b054b085
--- /dev/null
+++ b/backend/migrations/101_add_channel_features_config.sql
@@ -0,0 +1,2 @@
+-- Adds a per-channel JSON feature-flag column (e.g. web_search_emulation).
+-- NOT NULL DEFAULT '{}' keeps existing rows valid without a backfill, and
+-- IF NOT EXISTS makes the migration re-runnable.
+ALTER TABLE channels ADD COLUMN IF NOT EXISTS features_config JSONB NOT NULL DEFAULT '{}';
+COMMENT ON COLUMN channels.features_config IS '渠道特性配置(如 web_search_emulation),JSON 对象格式';
diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts
index b3455022..d49982aa 100644
--- a/frontend/src/api/admin/channels.ts
+++ b/frontend/src/api/admin/channels.ts
@@ -41,6 +41,7 @@ export interface Channel {
status: string
billing_model_source: string // "requested" | "upstream"
restrict_models: boolean
+ features_config?: Record
group_ids: number[]
model_pricing: ChannelModelPricing[]
model_mapping: Record> // platform → {src→dst}
@@ -56,6 +57,7 @@ export interface CreateChannelRequest {
model_mapping?: Record>
billing_model_source?: string
restrict_models?: boolean
+ features_config?: Record
}
export interface UpdateChannelRequest {
@@ -67,6 +69,7 @@ export interface UpdateChannelRequest {
model_mapping?: Record>
billing_model_source?: string
restrict_models?: boolean
+ features_config?: Record
}
interface PaginatedResponse {
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 504abe9c..7fc6c852 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -482,6 +482,42 @@ export async function updateBetaPolicySettings(
return data
}
+// --- Web Search Emulation Config ---
+
+export interface WebSearchProviderConfig {
+ type: 'brave' | 'tavily'
+ api_key: string
+ api_key_configured: boolean
+ priority: number
+ quota_limit: number
+ quota_refresh_interval: 'daily' | 'weekly' | 'monthly'
+ quota_used?: number
+ proxy_id: number | null
+ expires_at: number | null
+}
+
+export interface WebSearchEmulationConfig {
+ enabled: boolean
+ providers: WebSearchProviderConfig[]
+}
+
+export async function getWebSearchEmulationConfig(): Promise {
+ const { data } = await apiClient.get(
+ '/admin/settings/web-search-emulation'
+ )
+ return data
+}
+
+export async function updateWebSearchEmulationConfig(
+ config: WebSearchEmulationConfig
+): Promise {
+ const { data } = await apiClient.put(
+ '/admin/settings/web-search-emulation',
+ config
+ )
+ return data
+}
+
export const settingsAPI = {
getSettings,
updateSettings,
@@ -497,7 +533,9 @@ export const settingsAPI = {
getRectifierSettings,
updateRectifierSettings,
getBetaPolicySettings,
- updateBetaPolicySettings
+ updateBetaPolicySettings,
+ getWebSearchEmulationConfig,
+ updateWebSearchEmulationConfig
}
export default settingsAPI
diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue
index 380201c4..d5d8ff89 100644
--- a/frontend/src/components/account/CreateAccountModal.vue
+++ b/frontend/src/components/account/CreateAccountModal.vue
@@ -2325,6 +2325,22 @@
+
+
+
+
+
{{ t('admin.accounts.anthropic.webSearchEmulation') }}
+
+ {{ t('admin.accounts.anthropic.webSearchEmulationDesc') }}
+
+
+
+
+
+
(OPENAI_WS_MODE_OFF
const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
const anthropicPassthroughEnabled = ref(false)
+const webSearchEmulationEnabled = ref(false)
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages
const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream
@@ -3307,6 +3325,7 @@ watch(
}
if (newPlatform !== 'anthropic') {
anthropicPassthroughEnabled.value = false
+ webSearchEmulationEnabled.value = false
}
// Reset OAuth states
oauth.resetState()
@@ -3326,6 +3345,7 @@ watch(
}
if (platform !== 'anthropic' || category !== 'apikey') {
anthropicPassthroughEnabled.value = false
+ webSearchEmulationEnabled.value = false
}
}
)
@@ -3690,6 +3710,7 @@ const resetForm = () => {
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
anthropicPassthroughEnabled.value = false
+ webSearchEmulationEnabled.value = false
// Reset quota control state
windowCostEnabled.value = false
windowCostLimit.value = null
@@ -3777,6 +3798,11 @@ const buildAnthropicExtra = (base?: Record): Record 0 ? extra : undefined
}
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index 8b72d6d1..a67366fc 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -1149,10 +1149,61 @@
-
-
+
+
+
+
+
{{ t('admin.accounts.anthropic.webSearchEmulation') }}
+
+ {{ t('admin.accounts.anthropic.webSearchEmulationDesc') }}
+
+
+
+
+
+
+
+
-
{{ t('admin.accounts.quotaLimit') }}
+
{{ t('admin.accounts.quotaControl.title') }}
+
+ {{ t('admin.accounts.quotaControl.hint') }}
+
+
+
+
+
+
+
+
{{ t('admin.accounts.quotaControl.title') }}
{{ t('admin.accounts.quotaLimitHint') }}
@@ -1237,7 +1288,7 @@
-
+
(OPENAI_WS_MODE_OFF
const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
const anthropicPassthroughEnabled = ref(false)
+const webSearchEmulationEnabled = ref(false)
const editQuotaLimit = ref(null)
const editQuotaDailyLimit = ref(null)
const editQuotaWeeklyLimit = ref(null)
@@ -2067,6 +2120,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
anthropicPassthroughEnabled.value = false
+ webSearchEmulationEnabled.value = false
if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) {
openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true
openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, {
@@ -2087,6 +2141,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
}
if (newAccount.platform === 'anthropic' && newAccount.type === 'apikey') {
anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true
+ webSearchEmulationEnabled.value = extra?.web_search_emulation === true
}
// Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above)
@@ -2522,8 +2577,13 @@ function loadQuotaControlSettings(account: Account) {
customBaseUrlEnabled.value = false
customBaseUrl.value = ''
- // Only applies to Anthropic OAuth/SetupToken accounts
- if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) {
+ // Remaining quota control settings only apply to Anthropic accounts
+ if (account.platform !== 'anthropic') {
+ return
+ }
+
+ // Window cost / session limit only apply to Anthropic OAuth/SetupToken accounts
+ if (account.type !== 'oauth' && account.type !== 'setup-token') {
return
}
@@ -2949,7 +3009,7 @@ const handleSubmit = async () => {
// For Anthropic OAuth/SetupToken accounts, handle quota control settings in extra
if (props.account.platform === 'anthropic' && (props.account.type === 'oauth' || props.account.type === 'setup-token')) {
- const currentExtra = (props.account.extra as Record) || {}
+ const currentExtra = (updatePayload.extra as Record) || (props.account.extra as Record) || {}
const newExtra: Record = { ...currentExtra }
// Window cost limit settings
@@ -3037,15 +3097,20 @@ const handleSubmit = async () => {
updatePayload.extra = newExtra
}
- // For Anthropic API Key accounts, handle passthrough mode in extra
+ // For Anthropic API Key accounts, handle passthrough mode + web search emulation in extra
if (props.account.platform === 'anthropic' && props.account.type === 'apikey') {
- const currentExtra = (props.account.extra as Record) || {}
+ const currentExtra = (updatePayload.extra as Record) || (props.account.extra as Record) || {}
const newExtra: Record = { ...currentExtra }
if (anthropicPassthroughEnabled.value) {
newExtra.anthropic_passthrough = true
} else {
delete newExtra.anthropic_passthrough
}
+ if (webSearchEmulationEnabled.value) {
+ newExtra.web_search_emulation = true
+ } else {
+ delete newExtra.web_search_emulation
+ }
updatePayload.extra = newExtra
}
@@ -3089,20 +3154,27 @@ const handleSubmit = async () => {
const currentExtra = (updatePayload.extra as Record) ||
(props.account.extra as Record) || {}
const newExtra: Record = { ...currentExtra }
+ // Total quota
if (editQuotaLimit.value != null && editQuotaLimit.value > 0) {
newExtra.quota_limit = editQuotaLimit.value
} else {
delete newExtra.quota_limit
}
+ // Daily quota
if (editQuotaDailyLimit.value != null && editQuotaDailyLimit.value > 0) {
newExtra.quota_daily_limit = editQuotaDailyLimit.value
} else {
delete newExtra.quota_daily_limit
+ delete newExtra.quota_daily_used
+ delete newExtra.quota_daily_start
}
+ // Weekly quota
if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) {
newExtra.quota_weekly_limit = editQuotaWeeklyLimit.value
} else {
delete newExtra.quota_weekly_limit
+ delete newExtra.quota_weekly_used
+ delete newExtra.quota_weekly_start
}
// Quota reset mode config
if (editDailyResetMode.value === 'fixed') {
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index d6c87d52..99f8d535 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -1836,6 +1836,9 @@ export default {
defaultPerRequestPrice: 'Default per-request price (fallback when no tier matches)',
defaultImagePrice: 'Default image price (fallback when no tier matches)',
platformConfig: 'Platform Configuration',
+ webSearchEmulation: 'Web Search Emulation',
+ webSearchEmulationHint: '⚠️ When enabled, all accounts in this channel\'s Anthropic groups will intercept web_search requests. Use with caution.',
+ webSearchEmulationGlobalDisabled: 'Please enable the global switch first in Settings → Gateway → Web Search Emulation',
basicSettings: 'Basic Settings',
addPlatform: 'Add Platform',
noPlatforms: 'Click "Add Platform" to start configuring the channel',
@@ -2325,7 +2328,10 @@ export default {
anthropic: {
apiKeyPassthrough: 'Auto passthrough (auth only)',
apiKeyPassthroughDesc:
- 'Only applies to Anthropic API Key accounts. When enabled, messages/count_tokens are forwarded in passthrough mode with auth replacement only, while billing/concurrency/audit and safety filtering are preserved. Disable to roll back immediately.'
+ 'Only applies to Anthropic API Key accounts. When enabled, messages/count_tokens are forwarded in passthrough mode with auth replacement only, while billing/concurrency/audit and safety filtering are preserved. Disable to roll back immediately.',
+ webSearchEmulation: 'Web Search Emulation',
+ webSearchEmulationDesc:
+ 'Enable web search emulation for this API Key account. When a pure web_search request is detected, the gateway calls a third-party search API and constructs the response locally.',
},
modelRestriction: 'Model Restriction (Optional)',
modelWhitelist: 'Model Whitelist',
@@ -4358,6 +4364,31 @@ export default {
cchSigning: 'CCH Signing',
cchSigningHint: 'Sign the billing header in forwarded requests with CCH hash. When disabled, the placeholder is preserved.',
},
+ webSearchEmulation: {
+ title: 'Web Search Emulation',
+ description: 'Inject web search capability for Anthropic API Key accounts that don\'t natively support it',
+ enabled: 'Enable Web Search Emulation',
+ enabledHint: 'Global switch. When disabled, web search emulation is inactive for all channels and accounts.',
+ providers: 'Search Providers',
+ addProvider: 'Add Provider',
+ providerType: 'Provider Type',
+ apiKey: 'API Key',
+ apiKeyPlaceholder: 'Enter API Key',
+ apiKeyConfigured: 'Configured',
+ priority: 'Priority',
+ priorityHint: 'Lower number = higher priority',
+ quotaLimit: 'Quota Limit',
+ quotaLimitHint: '0 = unlimited',
+ quotaRefreshInterval: 'Refresh Interval',
+ quotaUsed: 'Used',
+ proxy: 'Proxy',
+ expiresAt: 'Expires At',
+ removeProvider: 'Remove',
+ daily: 'Daily',
+ weekly: 'Weekly',
+ monthly: 'Monthly',
+ noProviders: 'No search providers configured',
+ },
site: {
title: 'Site Settings',
description: 'Customize site branding',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 2038970a..7ef7ead0 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -1915,6 +1915,9 @@ export default {
defaultPerRequestPrice: '默认单次价格(未命中层级时使用)',
defaultImagePrice: '默认图片价格(未命中层级时使用)',
platformConfig: '平台配置',
+ webSearchEmulation: 'Web Search 模拟',
+ webSearchEmulationHint: '⚠️ 开启后该渠道下所有 Anthropic 分组的账号将自动拦截 web_search 请求,请谨慎操作',
+ webSearchEmulationGlobalDisabled: '请先在系统设置 → 网关 → Web Search 模拟中启用全局开关',
basicSettings: '基础设置',
addPlatform: '添加平台',
noPlatforms: '点击"添加平台"开始配置渠道',
@@ -2472,7 +2475,10 @@ export default {
anthropic: {
apiKeyPassthrough: '自动透传(仅替换认证)',
apiKeyPassthroughDesc:
- '仅对 Anthropic API Key 生效。开启后,messages/count_tokens 请求将透传上游并仅替换认证,保留计费/并发/审计及必要安全过滤;关闭即可回滚到现有兼容链路。'
+ '仅对 Anthropic API Key 生效。开启后,messages/count_tokens 请求将透传上游并仅替换认证,保留计费/并发/审计及必要安全过滤;关闭即可回滚到现有兼容链路。',
+ webSearchEmulation: 'Web Search 模拟',
+ webSearchEmulationDesc:
+ '为该 API Key 账号启用 web search 模拟。客户端发送纯 web_search 请求时,由网关调用第三方搜索 API 并构造响应返回。',
},
modelRestriction: '模型限制(可选)',
modelWhitelist: '模型白名单',
@@ -4520,6 +4526,31 @@ export default {
cchSigning: 'CCH 签名',
cchSigningHint: '对转发请求的 billing header 进行 CCH 哈希签名。关闭时保留原始占位符。',
},
+ webSearchEmulation: {
+ title: 'Web Search 模拟',
+ description: '为不原生支持搜索的 Anthropic API Key 账号注入 web search 能力',
+ enabled: '启用 Web Search 模拟',
+ enabledHint: '全局开关。关闭后所有渠道和账号的 web search 模拟均不生效。',
+ providers: '搜索服务商',
+ addProvider: '添加服务商',
+ providerType: '服务商类型',
+ apiKey: 'API Key',
+ apiKeyPlaceholder: '输入 API Key',
+ apiKeyConfigured: '已配置',
+ priority: '优先级',
+ priorityHint: '数值越小优先级越高',
+ quotaLimit: '配额上限',
+ quotaLimitHint: '0 表示无限制',
+ quotaRefreshInterval: '刷新周期',
+ quotaUsed: '已使用',
+ proxy: '代理',
+ expiresAt: '过期时间',
+ removeProvider: '删除',
+ daily: '每日',
+ weekly: '每周',
+ monthly: '每月',
+ noProviders: '未配置搜索服务商',
+ },
site: {
title: '站点设置',
description: '自定义站点品牌',
diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue
index 5c2f153b..ce8d3c9c 100644
--- a/frontend/src/views/admin/ChannelsView.vue
+++ b/frontend/src/views/admin/ChannelsView.vue
@@ -306,6 +306,24 @@
+
+
+
+
+
+ {{ t('admin.channels.form.webSearchEmulation') }}
+
+
+ {{ t('admin.channels.form.webSearchEmulationHint') }}
+
+
+ {{ t('admin.channels.form.webSearchEmulationGlobalDisabled') }}
+
+
+
+
+
+
@@ -423,6 +441,7 @@
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
+import { extractApiErrorMessage } from '@/utils/apiError'
import { adminAPI } from '@/api/admin'
import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels'
import type { PricingFormEntry } from '@/components/admin/channel/types'
@@ -446,6 +465,18 @@ import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
const { t } = useI18n()
const appStore = useAppStore()
+// Web Search global enabled state (loaded once on mount)
+const webSearchGlobalEnabled = ref(false)
+async function loadWebSearchGlobalState() {
+ try {
+ const cfg = await adminAPI.settings.getWebSearchEmulationConfig()
+ webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0
+ } catch (err: unknown) {
+ console.warn('Failed to load web search global state:', err)
+ webSearchGlobalEnabled.value = false
+ }
+}
+
// ── Platform Section type ──
interface PlatformSection {
platform: GroupPlatform
@@ -454,6 +485,7 @@ interface PlatformSection {
group_ids: number[]
model_mapping: Record
model_pricing: PricingFormEntry[]
+ web_search_emulation: boolean
}
// ── Table columns ──
@@ -565,7 +597,8 @@ function addPlatformSection(platform: GroupPlatform) {
collapsed: false,
group_ids: [],
model_mapping: {},
- model_pricing: []
+ model_pricing: [],
+ web_search_emulation: false,
})
}
@@ -679,10 +712,14 @@ function renameMappingKey(sectionIdx: number, oldKey: string, newKey: string) {
}
// ── Form ↔ API conversion ──
-function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record> } {
+function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record>, features_config: Record } {
const group_ids: number[] = []
const model_pricing: ChannelModelPricing[] = []
const model_mapping: Record> = {}
+ // Preserve existing features_config fields not managed by the form
+ const featuresConfig: Record = editingChannel.value?.features_config
+ ? { ...editingChannel.value.features_config }
+ : {}
for (const section of form.platforms) {
if (!section.enabled) continue
@@ -711,7 +748,19 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[
}
}
- return { group_ids, model_pricing, model_mapping }
+ // Collect web_search_emulation (only anthropic platform supports it)
+ const wsEmulation: Record = {}
+ for (const section of form.platforms) {
+ if (!section.enabled) continue
+ if (section.web_search_emulation && section.platform === 'anthropic') {
+ wsEmulation[section.platform] = true
+ }
+ }
+ if (Object.keys(wsEmulation).length > 0) {
+ featuresConfig.web_search_emulation = wsEmulation
+ }
+
+ return { group_ids, model_pricing, model_mapping, features_config: featuresConfig }
}
function apiToForm(channel: Channel): PlatformSection[] {
@@ -755,13 +804,19 @@ function apiToForm(channel: Channel): PlatformSection[] {
intervals: apiIntervalsToForm(p.intervals || [])
} as PricingFormEntry))
+ // Read web_search_emulation from features_config
+ const fc = channel.features_config
+ const wsEmulation = fc?.web_search_emulation as Record | undefined
+ const webSearchEnabled = wsEmulation?.[platform] === true
+
sections.push({
platform,
enabled: true,
collapsed: false,
group_ids: groupIds,
model_mapping: { ...mapping },
- model_pricing: pricing
+ model_pricing: pricing,
+ web_search_emulation: webSearchEnabled,
})
}
@@ -786,10 +841,10 @@ async function loadChannels() {
if (ctrl.signal.aborted || abortController !== ctrl) return
channels.value = response.items || []
pagination.total = response.total
- } catch (error: any) {
- if (error?.name === 'AbortError' || error?.code === 'ERR_CANCELED') return
- appStore.showError(t('admin.channels.loadError', 'Failed to load channels'))
- console.error('Error loading channels:', error)
+ } catch (error: unknown) {
+ const e = error as { name?: string; code?: string }
+ if (e?.name === 'AbortError' || e?.code === 'ERR_CANCELED') return
+ appStore.showError(extractApiErrorMessage(error, t('admin.channels.loadError', 'Failed to load channels')))
} finally {
if (abortController === ctrl) {
loading.value = false
@@ -969,8 +1024,7 @@ async function handleSubmit() {
}
}
- const { group_ids, model_pricing, model_mapping } = formToAPI()
- console.log('[handleSubmit] model_pricing to send:', JSON.stringify(model_pricing))
+ const { group_ids, model_pricing, model_mapping, features_config } = formToAPI()
submitting.value = true
try {
@@ -983,7 +1037,8 @@ async function handleSubmit() {
model_pricing,
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
billing_model_source: form.billing_model_source,
- restrict_models: form.restrict_models
+ restrict_models: form.restrict_models,
+ features_config,
}
await adminAPI.channels.update(editingChannel.value.id, req)
appStore.showSuccess(t('admin.channels.updateSuccess', 'Channel updated'))
@@ -995,19 +1050,18 @@ async function handleSubmit() {
model_pricing,
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
billing_model_source: form.billing_model_source,
- restrict_models: form.restrict_models
+ restrict_models: form.restrict_models,
+ features_config,
}
await adminAPI.channels.create(req)
appStore.showSuccess(t('admin.channels.createSuccess', 'Channel created'))
}
closeDialog()
loadChannels()
- } catch (error: any) {
- const msg = error.response?.data?.detail || (editingChannel.value
+ } catch (error: unknown) {
+ appStore.showError(extractApiErrorMessage(error, editingChannel.value
? t('admin.channels.updateError', 'Failed to update channel')
- : t('admin.channels.createError', 'Failed to create channel'))
- appStore.showError(msg)
- console.error('Error saving channel:', error)
+ : t('admin.channels.createError', 'Failed to create channel')))
} finally {
submitting.value = false
}
@@ -1045,9 +1099,8 @@ async function confirmDelete() {
showDeleteDialog.value = false
deletingChannel.value = null
loadChannels()
- } catch (error: any) {
- appStore.showError(error.response?.data?.detail || t('admin.channels.deleteError', 'Failed to delete channel'))
- console.error('Error deleting channel:', error)
+ } catch (error: unknown) {
+ appStore.showError(extractApiErrorMessage(error, t('admin.channels.deleteError', 'Failed to delete channel')))
}
}
@@ -1055,6 +1108,7 @@ async function confirmDelete() {
onMounted(() => {
loadChannels()
loadGroups()
+ loadWebSearchGlobalState()
})
onUnmounted(() => {
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 20f9318c..6abc725a 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -630,6 +630,108 @@
{{ t('admin.settings.betaPolicy.errorMessageHint') }}
+
+
+
+
+ {{ t('admin.settings.betaPolicy.quickPresets') }}
+
+
+
+ {{ preset.label }}
+
+
+
+
+
+
+
+ {{ t('admin.settings.betaPolicy.modelWhitelist') }}
+
+
+ {{ t('admin.settings.betaPolicy.modelWhitelistHint') }}
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.betaPolicy.addModelPattern') }}
+
+
+
+ {{ t('admin.settings.betaPolicy.commonPatterns') }}:
+
+ {{ pattern }}
+
+
+
+
+
+
+
+ {{ t('admin.settings.betaPolicy.fallbackAction') }}
+
+
+
+ {{ t('admin.settings.betaPolicy.fallbackActionHint') }}
+
+
+
+
+
+ {{ t('admin.settings.betaPolicy.errorMessageHint') }}
+
+
+
@@ -1022,7 +1124,327 @@
-
+
+
+
+
+
+ {{ t('admin.settings.oidc.title') }}
+
+
+ {{ t('admin.settings.oidc.description') }}
+
+
+
+
+
+
{{
+ t('admin.settings.oidc.enable')
+ }}
+
+ {{ t('admin.settings.oidc.enableHint') }}
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.oidc.providerName') }}
+
+
+
+
+
+
+ {{ t('admin.settings.oidc.clientId') }}
+
+
+
+
+
+
+ {{ t('admin.settings.oidc.clientSecret') }}
+
+
+
+ {{
+ form.oidc_connect_client_secret_configured
+ ? t('admin.settings.oidc.clientSecretConfiguredHint')
+ : t('admin.settings.oidc.clientSecretHint')
+ }}
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.oidc.scopes') }}
+
+
+
+ {{ t('admin.settings.oidc.scopesHint') }}
+
+
+
+
+
+ {{ t('admin.settings.oidc.redirectUrl') }}
+
+
+
+
+ {{ t('admin.settings.oidc.quickSetCopy') }}
+
+
+ {{ oidcRedirectUrlSuggestion }}
+
+
+
+ {{ t('admin.settings.oidc.redirectUrlHint') }}
+
+
+
+
+
+ {{ t('admin.settings.oidc.frontendRedirectUrl') }}
+
+
+
+ {{ t('admin.settings.oidc.frontendRedirectUrlHint') }}
+
+
+
+
+
+
+
+ {{ t('admin.settings.oidc.tokenAuthMethod') }}
+
+
+ client_secret_post
+ client_secret_basic
+ none
+
+
+
+
+
+ {{ t('admin.settings.oidc.clockSkewSeconds') }}
+
+
+
+
+
+
+ {{ t('admin.settings.oidc.allowedSigningAlgs') }}
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.oidc.usePkce') }}
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.oidc.validateIdToken') }}
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.oidc.requireEmailVerified') }}
+
+
+
+
+
+
+
+
+
+
+
@@ -1274,8 +1696,122 @@
+
+
+
+
+
+ {{ t('admin.settings.gatewayForwarding.cchSigning') }}
+
+
+ {{ t('admin.settings.gatewayForwarding.cchSigningHint') }}
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.webSearchEmulation.title') }}
+
+
+ {{ t('admin.settings.webSearchEmulation.description') }}
+
+
+
+
+
+
+
+ {{ t('admin.settings.webSearchEmulation.enabled') }}
+
+
+ {{ t('admin.settings.webSearchEmulation.enabledHint') }}
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.webSearchEmulation.providers') }}
+
+
+ {{ t('admin.settings.webSearchEmulation.addProvider') }}
+
+
+
+
+ {{ t('admin.settings.webSearchEmulation.noProviders') }}
+
+
+
+
+
+
+ {{ t('admin.settings.webSearchEmulation.removeProvider') }}
+
+
+
+
+
+ {{ t('admin.settings.webSearchEmulation.apiKey') }}
+
+
+
+
{{ t('admin.settings.webSearchEmulation.priority') }}
+
+
{{ t('admin.settings.webSearchEmulation.priorityHint') }}
+
+
+
{{ t('admin.settings.webSearchEmulation.quotaLimit') }}
+
+
{{ t('admin.settings.webSearchEmulation.quotaLimitHint') }}
+
+ {{ t('admin.settings.webSearchEmulation.quotaUsed') }}: {{ provider.quota_used }} / {{ provider.quota_limit || '∞' }}
+
+
+
+ {{ t('admin.settings.webSearchEmulation.quotaRefreshInterval') }}
+
+
+
+
+
+
{{ t('admin.settings.webSearchEmulation.proxy') }}
+
+
+
+
+
+
+
@@ -1353,6 +1889,48 @@
+
+
+
+ {{ t('admin.settings.site.tablePreferencesTitle') }}
+
+
+ {{ t('admin.settings.site.tablePreferencesDescription') }}
+
+
+
+
+ {{ t('admin.settings.site.tableDefaultPageSize') }}
+
+
+
+ {{ t('admin.settings.site.tableDefaultPageSizeHint') }}
+
+
+
+
+ {{ t('admin.settings.site.tablePageSizeOptions') }}
+
+
+
+ {{ t('admin.settings.site.tablePageSizeOptionsHint') }}
+
+
+
+
+
@@ -1644,7 +2222,13 @@
-
+
{{ t('admin.settings.payment.maxPendingOrders') }}
{{ t('admin.settings.payment.loadBalanceStrategy') }}
-
-
-
-
-
-
- {{ t('admin.settings.payment.cancelRateLimit') }}
-
-
-
-
-
{{ t('admin.settings.payment.cancelRateLimitEvery') }}
-
-
-
{{ t('admin.settings.payment.cancelRateLimitAllowMax') }}
-
-
{{ t('admin.settings.payment.cancelRateLimitTimes') }}
-
+
-
{{ t('admin.settings.payment.cancelRateLimitHint') }}
@@ -2029,7 +2617,9 @@ import { adminAPI } from '@/api'
import type {
SystemSettings,
UpdateSettingsRequest,
- DefaultSubscriptionSetting
+ DefaultSubscriptionSetting,
+ WebSearchEmulationConfig,
+ WebSearchProviderConfig,
} from '@/api/admin/settings'
import type { AdminGroup } from '@/types'
import type { ProviderInstance } from '@/types/payment'
@@ -2042,6 +2632,7 @@ import PaymentProviderDialog from '@/components/payment/PaymentProviderDialog.vu
import GroupBadge from '@/components/common/GroupBadge.vue'
import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
import Toggle from '@/components/common/Toggle.vue'
+import ProxySelector from '@/components/common/ProxySelector.vue'
import ImageUpload from '@/components/common/ImageUpload.vue'
import BackupSettings from '@/views/admin/BackupView.vue'
import { useClipboard } from '@/composables/useClipboard'
@@ -2055,11 +2646,11 @@ import {
parseRegistrationEmailSuffixWhitelistInput
} from '@/utils/registrationEmailPolicy'
-const { t } = useI18n()
+const { t, locale } = useI18n()
const appStore = useAppStore()
const adminSettingsStore = useAdminSettingsStore()
-type SettingsTab = 'general' | 'security' | 'users' | 'gateway' | 'payment' | 'email' | 'backup' | 'data'
+type SettingsTab = 'general' | 'security' | 'users' | 'gateway' | 'payment' | 'email' | 'backup'
const activeTab = ref
('general')
const settingsTabs = [
{ key: 'general' as SettingsTab, icon: 'home' as const },
@@ -2081,6 +2672,7 @@ const smtpPasswordManuallyEdited = ref(false)
const testEmailAddress = ref('')
const registrationEmailSuffixWhitelistTags = ref([])
const registrationEmailSuffixWhitelistDraft = ref('')
+const tablePageSizeOptionsInput = ref('10, 20, 50, 100')
// Admin API Key 状态
const adminApiKeyLoading = ref(true)
@@ -2129,9 +2721,16 @@ const betaPolicyForm = reactive({
action: 'pass' | 'filter' | 'block'
scope: 'all' | 'oauth' | 'apikey' | 'bedrock'
error_message?: string
+ model_whitelist?: string[]
+ fallback_action?: 'pass' | 'filter' | 'block'
+ fallback_error_message?: string
}>
})
+const tablePageSizeMin = 5
+const tablePageSizeMax = 1000
+const tablePageSizeDefault = 20
+
interface DefaultSubscriptionGroupOption {
value: number
label: string
@@ -2146,6 +2745,7 @@ type SettingsForm = SystemSettings & {
smtp_password: string
turnstile_secret_key: string
linuxdo_connect_client_secret: string
+ oidc_connect_client_secret: string
}
const form = reactive({
@@ -2170,7 +2770,8 @@ const form = reactive({
backend_mode_enabled: false,
hide_ccs_import_button: false,
payment_enabled: false, payment_min_amount: 1, payment_max_amount: 10000, payment_daily_limit: 50000, payment_max_pending_orders: 3, payment_order_timeout_minutes: 30, payment_balance_disabled: false, payment_enabled_types: [], payment_help_image_url: '', payment_help_text: '', payment_product_name_prefix: '', payment_product_name_suffix: '', payment_load_balance_strategy: 'round-robin', payment_cancel_rate_limit_enabled: false, payment_cancel_rate_limit_max: 10, payment_cancel_rate_limit_window: 1, payment_cancel_rate_limit_unit: 'day', payment_cancel_rate_limit_window_mode: 'rolling',
- sora_client_enabled: false,
+ table_default_page_size: tablePageSizeDefault,
+ table_page_size_options: [10, 20, 50, 100],
custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>,
custom_endpoints: [] as Array<{name: string; endpoint: string; description: string}>,
frontend_url: '',
@@ -2193,6 +2794,30 @@ const form = reactive({
linuxdo_connect_client_secret: '',
linuxdo_connect_client_secret_configured: false,
linuxdo_connect_redirect_url: '',
+ // Generic OIDC OAuth 登录
+ oidc_connect_enabled: false,
+ oidc_connect_provider_name: 'OIDC',
+ oidc_connect_client_id: '',
+ oidc_connect_client_secret: '',
+ oidc_connect_client_secret_configured: false,
+ oidc_connect_issuer_url: '',
+ oidc_connect_discovery_url: '',
+ oidc_connect_authorize_url: '',
+ oidc_connect_token_url: '',
+ oidc_connect_userinfo_url: '',
+ oidc_connect_jwks_url: '',
+ oidc_connect_scopes: 'openid email profile',
+ oidc_connect_redirect_url: '',
+ oidc_connect_frontend_redirect_url: '/auth/oidc/callback',
+ oidc_connect_token_auth_method: 'client_secret_post',
+ oidc_connect_use_pkce: false,
+ oidc_connect_validate_id_token: true,
+ oidc_connect_allowed_signing_algs: 'RS256,ES256,PS256',
+ oidc_connect_clock_skew_seconds: 120,
+ oidc_connect_require_email_verified: false,
+ oidc_connect_userinfo_email_path: '',
+ oidc_connect_userinfo_id_path: '',
+ oidc_connect_userinfo_username_path: '',
// Model fallback
enable_model_fallback: false,
fallback_model_anthropic: 'claude-3-5-sonnet-20241022',
@@ -2214,9 +2839,60 @@ const form = reactive({
allow_ungrouped_key_scheduling: false,
// Gateway forwarding behavior
enable_fingerprint_unification: true,
- enable_metadata_passthrough: false
+ enable_metadata_passthrough: false,
+ enable_cch_signing: false
})
+// Web Search Emulation config (loaded/saved separately)
+const DEFAULT_WEB_SEARCH_QUOTA_LIMIT = 1000
+
+const webSearchConfig = reactive({
+ enabled: false,
+ providers: [],
+})
+
+function addWebSearchProvider() {
+ webSearchConfig.providers.push({
+ type: 'brave',
+ api_key: '',
+ api_key_configured: false,
+ priority: webSearchConfig.providers.length + 1,
+ quota_limit: DEFAULT_WEB_SEARCH_QUOTA_LIMIT,
+ quota_refresh_interval: 'monthly',
+ proxy_id: null,
+ expires_at: null,
+ } as WebSearchProviderConfig)
+}
+
+async function loadWebSearchConfig() {
+ try {
+ const resp = await adminAPI.settings.getWebSearchEmulationConfig()
+ if (resp) {
+ webSearchConfig.enabled = resp.enabled || false
+ webSearchConfig.providers = resp.providers || []
+ }
+ } catch (err: unknown) {
+ // 404 is expected when config hasn't been created yet; show error for other failures
+ const status = (err as { status?: number })?.status
+ if (status !== 404 && status !== undefined) {
+ appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ }
+ }
+}
+
+async function saveWebSearchConfig(): Promise {
+ try {
+ await adminAPI.settings.updateWebSearchEmulationConfig({
+ enabled: webSearchConfig.enabled,
+ providers: webSearchConfig.providers as WebSearchProviderConfig[],
+ })
+ return true
+ } catch (err: unknown) {
+ appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ return false
+ }
+}
+
const defaultSubscriptionGroupOptions = computed(() =>
subscriptionGroups.value.map((group) => ({
value: group.id,
@@ -2312,6 +2988,21 @@ async function setAndCopyLinuxdoRedirectUrl() {
await copyToClipboard(url, t('admin.settings.linuxdo.redirectUrlSetAndCopied'))
}
+const oidcRedirectUrlSuggestion = computed(() => {
+ if (typeof window === 'undefined') return ''
+ const origin =
+ window.location.origin || `${window.location.protocol}//${window.location.host}`
+ return `${origin}/api/v1/auth/oauth/oidc/callback`
+})
+
+async function setAndCopyOIDCRedirectUrl() {
+ const url = oidcRedirectUrlSuggestion.value
+ if (!url) return
+
+ form.oidc_connect_redirect_url = url
+ await copyToClipboard(url, t('admin.settings.oidc.redirectUrlSetAndCopied'))
+}
+
// Custom menu item management
function addMenuItem() {
form.custom_menu_items.push({
@@ -2354,6 +3045,35 @@ function removeEndpoint(index: number) {
form.custom_endpoints.splice(index, 1)
}
+function formatTablePageSizeOptions(options: number[]): string {
+ return options.join(', ')
+}
+
+function parseTablePageSizeOptionsInput(raw: string): number[] | null {
+ const tokens = raw
+ .split(',')
+ .map((token) => token.trim())
+ .filter((token) => token.length > 0)
+
+ if (tokens.length === 0) {
+ return null
+ }
+
+ const parsed = tokens.map((token) => Number(token))
+ if (parsed.some((value) => !Number.isInteger(value))) {
+ return null
+ }
+
+ const deduped = Array.from(new Set(parsed)).sort((a, b) => a - b)
+ if (
+ deduped.some((value) => value < tablePageSizeMin || value > tablePageSizeMax)
+ ) {
+ return null
+ }
+
+ return deduped
+}
+
async function loadSettings() {
loading.value = true
loadFailed.value = false
@@ -2378,11 +3098,18 @@ async function loadSettings() {
registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains(
settings.registration_email_suffix_whitelist
)
+ tablePageSizeOptionsInput.value = formatTablePageSizeOptions(
+ Array.isArray(settings.table_page_size_options) ? settings.table_page_size_options : [10, 20, 50, 100]
+ )
registrationEmailSuffixWhitelistDraft.value = ''
form.smtp_password = ''
smtpPasswordManuallyEdited.value = false
form.turnstile_secret_key = ''
form.linuxdo_connect_client_secret = ''
+ form.oidc_connect_client_secret = ''
+
+ // Load web search emulation config separately
+ await loadWebSearchConfig()
} catch (error: unknown) {
loadFailed.value = true
appStore.showError(extractApiErrorMessage(error, t('admin.settings.failedToLoad')))
@@ -2420,6 +3147,37 @@ function removeDefaultSubscription(index: number) {
async function saveSettings() {
saving.value = true
try {
+ const normalizedTableDefaultPageSize = Math.floor(Number(form.table_default_page_size))
+ if (
+ !Number.isInteger(normalizedTableDefaultPageSize) ||
+ normalizedTableDefaultPageSize < tablePageSizeMin ||
+ normalizedTableDefaultPageSize > tablePageSizeMax
+ ) {
+ appStore.showError(
+ t('admin.settings.site.tableDefaultPageSizeRangeError', {
+ min: tablePageSizeMin,
+ max: tablePageSizeMax
+ })
+ )
+ return
+ }
+
+ const normalizedTablePageSizeOptions = parseTablePageSizeOptionsInput(
+ tablePageSizeOptionsInput.value
+ )
+ if (!normalizedTablePageSizeOptions) {
+ appStore.showError(
+ t('admin.settings.site.tablePageSizeOptionsFormatError', {
+ min: tablePageSizeMin,
+ max: tablePageSizeMax
+ })
+ )
+ return
+ }
+
+ form.table_default_page_size = normalizedTableDefaultPageSize
+ form.table_page_size_options = normalizedTablePageSizeOptions
+
const normalizedDefaultSubscriptions = form.default_subscriptions
.filter((item) => item.group_id > 0 && item.validity_days > 0)
.map((item: DefaultSubscriptionSetting) => ({
@@ -2480,6 +3238,8 @@ async function saveSettings() {
home_content: form.home_content,
backend_mode_enabled: form.backend_mode_enabled,
hide_ccs_import_button: form.hide_ccs_import_button,
+ table_default_page_size: form.table_default_page_size,
+ table_page_size_options: form.table_page_size_options,
custom_menu_items: form.custom_menu_items,
custom_endpoints: form.custom_endpoints,
frontend_url: form.frontend_url,
@@ -2497,6 +3257,28 @@ async function saveSettings() {
linuxdo_connect_client_id: form.linuxdo_connect_client_id,
linuxdo_connect_client_secret: form.linuxdo_connect_client_secret || undefined,
linuxdo_connect_redirect_url: form.linuxdo_connect_redirect_url,
+ oidc_connect_enabled: form.oidc_connect_enabled,
+ oidc_connect_provider_name: form.oidc_connect_provider_name,
+ oidc_connect_client_id: form.oidc_connect_client_id,
+ oidc_connect_client_secret: form.oidc_connect_client_secret || undefined,
+ oidc_connect_issuer_url: form.oidc_connect_issuer_url,
+ oidc_connect_discovery_url: form.oidc_connect_discovery_url,
+ oidc_connect_authorize_url: form.oidc_connect_authorize_url,
+ oidc_connect_token_url: form.oidc_connect_token_url,
+ oidc_connect_userinfo_url: form.oidc_connect_userinfo_url,
+ oidc_connect_jwks_url: form.oidc_connect_jwks_url,
+ oidc_connect_scopes: form.oidc_connect_scopes,
+ oidc_connect_redirect_url: form.oidc_connect_redirect_url,
+ oidc_connect_frontend_redirect_url: form.oidc_connect_frontend_redirect_url,
+ oidc_connect_token_auth_method: form.oidc_connect_token_auth_method,
+ oidc_connect_use_pkce: form.oidc_connect_use_pkce,
+ oidc_connect_validate_id_token: form.oidc_connect_validate_id_token,
+ oidc_connect_allowed_signing_algs: form.oidc_connect_allowed_signing_algs,
+ oidc_connect_clock_skew_seconds: form.oidc_connect_clock_skew_seconds,
+ oidc_connect_require_email_verified: form.oidc_connect_require_email_verified,
+ oidc_connect_userinfo_email_path: form.oidc_connect_userinfo_email_path,
+ oidc_connect_userinfo_id_path: form.oidc_connect_userinfo_id_path,
+ oidc_connect_userinfo_username_path: form.oidc_connect_userinfo_username_path,
enable_model_fallback: form.enable_model_fallback,
fallback_model_anthropic: form.fallback_model_anthropic,
fallback_model_openai: form.fallback_model_openai,
@@ -2509,6 +3291,7 @@ async function saveSettings() {
allow_ungrouped_key_scheduling: form.allow_ungrouped_key_scheduling,
enable_fingerprint_unification: form.enable_fingerprint_unification,
enable_metadata_passthrough: form.enable_metadata_passthrough,
+ enable_cch_signing: form.enable_cch_signing,
// Payment configuration
payment_enabled: form.payment_enabled,
payment_min_amount: Number(form.payment_min_amount) || 0,
@@ -2539,15 +3322,23 @@ async function saveSettings() {
registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains(
updated.registration_email_suffix_whitelist
)
+ tablePageSizeOptionsInput.value = formatTablePageSizeOptions(
+ Array.isArray(updated.table_page_size_options) ? updated.table_page_size_options : [10, 20, 50, 100]
+ )
registrationEmailSuffixWhitelistDraft.value = ''
form.smtp_password = ''
smtpPasswordManuallyEdited.value = false
form.turnstile_secret_key = ''
form.linuxdo_connect_client_secret = ''
+ form.oidc_connect_client_secret = ''
+ // Save web search emulation config separately (errors handled internally)
+ const wsOk = await saveWebSearchConfig()
// Refresh cached settings so sidebar/header update immediately
await appStore.fetchPublicSettings(true)
await adminSettingsStore.fetch(true)
- appStore.showSuccess(t('admin.settings.settingsSaved'))
+ if (wsOk) {
+ appStore.showSuccess(t('admin.settings.settingsSaved'))
+ }
} catch (error: unknown) {
appStore.showError(extractApiErrorMessage(error, t('admin.settings.failedToSave')))
} finally {
@@ -2785,10 +3576,48 @@ const betaDisplayNames: Record = {
'context-1m-2025-08-07': 'Context 1M'
}
+// 快捷预设:按 beta_token 定义预设方案
+const betaPresets: Record> = {
+ 'context-1m-2025-08-07': [
+ {
+ label: t('admin.settings.betaPolicy.presetOpusOnly'),
+ description: t('admin.settings.betaPolicy.presetOpusOnlyDesc'),
+ action: 'pass',
+ model_whitelist: ['claude-opus-4-6'],
+ fallback_action: 'filter',
+ },
+ ],
+}
+
+// 常用模型模式(具体 ID + 通配符示例)
+const commonModelPatterns = ['claude-opus-4-6', 'claude-sonnet-4-6', 'claude-opus-*', 'claude-sonnet-*']
+
function getBetaDisplayName(token: string): string {
return betaDisplayNames[token] || token
}
+function applyBetaPreset(
+ rule: (typeof betaPolicyForm.rules)[number],
+ preset: { action: 'pass' | 'filter' | 'block'; model_whitelist: string[]; fallback_action: 'pass' | 'filter' | 'block' }
+) {
+ rule.action = preset.action
+ rule.model_whitelist = [...preset.model_whitelist]
+ rule.fallback_action = preset.fallback_action
+}
+
+function addQuickPattern(rule: (typeof betaPolicyForm.rules)[number], pattern: string) {
+ if (!rule.model_whitelist) rule.model_whitelist = []
+ if (!rule.model_whitelist.includes(pattern)) {
+ rule.model_whitelist.push(pattern)
+ }
+}
+
async function loadBetaPolicySettings() {
betaPolicyLoading.value = true
try {
@@ -2804,8 +3633,22 @@ async function loadBetaPolicySettings() {
async function saveBetaPolicySettings() {
betaPolicySaving.value = true
try {
+ // Clean up empty patterns before saving
+ const cleanedRules = betaPolicyForm.rules.map(rule => {
+ const whitelist = rule.model_whitelist?.filter(p => p.trim() !== '')
+ const hasWhitelist = whitelist && whitelist.length > 0
+ return {
+ beta_token: rule.beta_token,
+ action: rule.action,
+ scope: rule.scope,
+ error_message: rule.error_message,
+ model_whitelist: hasWhitelist ? whitelist : undefined,
+ fallback_action: hasWhitelist ? (rule.fallback_action || 'pass') : undefined,
+ fallback_error_message: hasWhitelist && rule.fallback_action === 'block' ? rule.fallback_error_message : undefined,
+ }
+ })
const updated = await adminAPI.settings.updateBetaPolicySettings({
- rules: betaPolicyForm.rules
+ rules: cleanedRules
})
betaPolicyForm.rules = updated.rules
appStore.showSuccess(t('admin.settings.betaPolicy.saved'))
@@ -2919,15 +3762,15 @@ async function handleSaveProvider(payload: Partial) {
providerSaving.value = true
try {
if (editingProvider.value) {
- const updated = await adminAPI.payment.updateProvider(editingProvider.value.id, payload)
- // Update in place to preserve list order
- const idx = providers.value.findIndex(p => p.id === editingProvider.value!.id)
- if (idx >= 0 && updated.data) providers.value[idx] = updated.data
+ await adminAPI.payment.updateProvider(editingProvider.value.id, payload)
} else {
await adminAPI.payment.createProvider(payload)
- loadProviders()
}
showProviderDialog.value = false
+ // Reload full list (API returns decrypted/formatted data with correct sort order)
+ await loadProviders()
+ // Auto-save settings so provider changes take effect immediately
+ await saveSettings()
} catch (err: unknown) {
appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value))
} finally {
From 7fad9f604fb081fd850034b4d2bdf8f30dde04fd Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 00:09:28 +0800
Subject: [PATCH 15/88] fix(test): add web_search_emulation_enabled to API
contract test
The settings API response now includes the new field; update the
expected snapshot in TestAPIContracts to match.
---
backend/internal/server/api_contract_test.go | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 1a4892fa..08291faa 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -204,11 +204,10 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
- "claude_code_only": false,
+ "claude_code_only": false,
"allow_messages_dispatch": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
- "allow_messages_dispatch": false,
"require_oauth_only": false,
"require_privacy_set": false,
"created_at": "2025-01-02T03:04:05Z",
@@ -587,26 +586,27 @@ func TestAPIContracts(t *testing.T) {
"enable_cch_signing": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
+ "web_search_emulation_enabled": false,
+ "custom_menu_items": [],
+ "custom_endpoints": [],
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
"payment_daily_limit": 0,
"payment_order_timeout_minutes": 0,
"payment_max_pending_orders": 0,
- "payment_enabled_types": null,
"payment_balance_disabled": false,
"payment_load_balance_strategy": "",
"payment_product_name_prefix": "",
"payment_product_name_suffix": "",
"payment_help_image_url": "",
"payment_help_text": "",
+ "payment_enabled_types": null,
"payment_cancel_rate_limit_enabled": false,
"payment_cancel_rate_limit_max": 0,
"payment_cancel_rate_limit_window": 0,
"payment_cancel_rate_limit_unit": "",
- "payment_cancel_rate_limit_window_mode": "",
- "custom_menu_items": [],
- "custom_endpoints": []
+ "payment_cancel_rate_limit_window_mode": ""
}
}`,
},
From 7535e312e009cc67301d6ef2ac6dc3069259515c Mon Sep 17 00:00:00 2001
From: erio
Date: Sat, 11 Apr 2026 23:39:49 +0800
Subject: [PATCH 16/88] feat(channels): add custom account stats pricing rules
Allow channels to configure independent model pricing for account
statistics cost calculation, decoupled from user billing.
Backend:
- Migration 101: channels.apply_pricing_to_account_stats toggle,
channel_account_stats_pricing_rules/model_pricing tables,
usage_logs.account_stats_cost column
- resolveAccountStatsCost: match rules by group/account, then channel
pricing, fallback to original formula when unconfigured
- Integrate into both GatewayService.recordUsageCore and
OpenAIGatewayService.RecordUsage
- Update 8 account stats SQL queries to use
COALESCE(account_stats_cost, total_cost) * account_rate_multiplier
- 23 unit tests for matching, pricing lookup, and cost calculation
Frontend:
- Channel edit dialog: toggle + custom rules UI with group/account
multi-select and pricing entry cards
- API types and i18n (zh/en)
---
.../internal/handler/admin/channel_handler.go | 168 ++++---
backend/internal/repository/channel_repo.go | 91 ++--
.../channel_repo_account_stats_pricing.go | 170 +++++++
backend/internal/repository/usage_log_repo.go | 39 +-
.../usage_log_repo_request_type_test.go | 29 +-
.../internal/service/account_stats_pricing.go | 192 ++++++++
.../service/account_stats_pricing_test.go | 430 ++++++++++++++++++
backend/internal/service/channel.go | 50 +-
backend/internal/service/channel_service.go | 77 ++--
backend/internal/service/gateway_service.go | 17 +
.../service/openai_gateway_service.go | 9 +
backend/internal/service/usage_log.go | 2 +
.../101_add_account_stats_pricing.sql | 38 ++
frontend/src/api/admin/channels.ts | 17 +-
frontend/src/i18n/locales/en.ts | 13 +-
frontend/src/i18n/locales/zh.ts | 13 +-
frontend/src/views/admin/ChannelsView.vue | 338 +++++++++++---
17 files changed, 1449 insertions(+), 244 deletions(-)
create mode 100644 backend/internal/repository/channel_repo_account_stats_pricing.go
create mode 100644 backend/internal/service/account_stats_pricing.go
create mode 100644 backend/internal/service/account_stats_pricing_test.go
create mode 100644 backend/migrations/101_add_account_stats_pricing.sql
diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go
index 9cefc792..2d4cd56a 100644
--- a/backend/internal/handler/admin/channel_handler.go
+++ b/backend/internal/handler/admin/channel_handler.go
@@ -26,28 +26,30 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s
// --- Request / Response types ---
type createChannelRequest struct {
- Name string `json:"name" binding:"required,max=100"`
- Description string `json:"description"`
- GroupIDs []int64 `json:"group_ids"`
- ModelPricing []channelModelPricingRequest `json:"model_pricing"`
- ModelMapping map[string]map[string]string `json:"model_mapping"`
- BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
- RestrictModels bool `json:"restrict_models"`
- Features string `json:"features"`
- FeaturesConfig map[string]any `json:"features_config"`
+ Name string `json:"name" binding:"required,max=100"`
+ Description string `json:"description"`
+ GroupIDs []int64 `json:"group_ids"`
+ ModelPricing []channelModelPricingRequest `json:"model_pricing"`
+ ModelMapping map[string]map[string]string `json:"model_mapping"`
+ BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
+ RestrictModels bool `json:"restrict_models"`
+ Features string `json:"features"`
+ ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
+ AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
}
type updateChannelRequest struct {
- Name string `json:"name" binding:"omitempty,max=100"`
- Description *string `json:"description"`
- Status string `json:"status" binding:"omitempty,oneof=active disabled"`
- GroupIDs *[]int64 `json:"group_ids"`
- ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
- ModelMapping map[string]map[string]string `json:"model_mapping"`
- BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
- RestrictModels *bool `json:"restrict_models"`
- Features *string `json:"features"`
- FeaturesConfig map[string]any `json:"features_config"`
+ Name string `json:"name" binding:"omitempty,max=100"`
+ Description *string `json:"description"`
+ Status string `json:"status" binding:"omitempty,oneof=active disabled"`
+ GroupIDs *[]int64 `json:"group_ids"`
+ ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
+ ModelMapping map[string]map[string]string `json:"model_mapping"`
+ BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
+ RestrictModels *bool `json:"restrict_models"`
+ Features *string `json:"features"`
+ ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"`
+ AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
}
type channelModelPricingRequest struct {
@@ -75,20 +77,28 @@ type pricingIntervalRequest struct {
SortOrder int `json:"sort_order"`
}
+type accountStatsPricingRuleRequest struct {
+ Name string `json:"name"`
+ GroupIDs []int64 `json:"group_ids"`
+ AccountIDs []int64 `json:"account_ids"`
+ Pricing []channelModelPricingRequest `json:"pricing"`
+}
+
type channelResponse struct {
- ID int64 `json:"id"`
- Name string `json:"name"`
- Description string `json:"description"`
- Status string `json:"status"`
- BillingModelSource string `json:"billing_model_source"`
- RestrictModels bool `json:"restrict_models"`
- Features string `json:"features"`
- FeaturesConfig map[string]any `json:"features_config"`
- GroupIDs []int64 `json:"group_ids"`
- ModelPricing []channelModelPricingResponse `json:"model_pricing"`
- ModelMapping map[string]map[string]string `json:"model_mapping"`
- CreatedAt string `json:"created_at"`
- UpdatedAt string `json:"updated_at"`
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Status string `json:"status"`
+ BillingModelSource string `json:"billing_model_source"`
+ RestrictModels bool `json:"restrict_models"`
+ Features string `json:"features"`
+ GroupIDs []int64 `json:"group_ids"`
+ ModelPricing []channelModelPricingResponse `json:"model_pricing"`
+ ModelMapping map[string]map[string]string `json:"model_mapping"`
+ ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
+ AccountStatsPricingRules []accountStatsPricingRuleResponse `json:"account_stats_pricing_rules"`
+ CreatedAt string `json:"created_at"`
+ UpdatedAt string `json:"updated_at"`
}
type channelModelPricingResponse struct {
@@ -118,6 +128,14 @@ type pricingIntervalResponse struct {
SortOrder int `json:"sort_order"`
}
+type accountStatsPricingRuleResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ GroupIDs []int64 `json:"group_ids"`
+ AccountIDs []int64 `json:"account_ids"`
+ Pricing []channelModelPricingResponse `json:"pricing"`
+}
+
func channelToResponse(ch *service.Channel) *channelResponse {
if ch == nil {
return nil
@@ -129,7 +147,6 @@ func channelToResponse(ch *service.Channel) *channelResponse {
Status: ch.Status,
RestrictModels: ch.RestrictModels,
Features: ch.Features,
- FeaturesConfig: ch.FeaturesConfig,
GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
@@ -150,6 +167,29 @@ func channelToResponse(ch *service.Channel) *channelResponse {
for _, p := range ch.ModelPricing {
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
}
+
+ resp.ApplyPricingToAccountStats = ch.ApplyPricingToAccountStats
+ resp.AccountStatsPricingRules = make([]accountStatsPricingRuleResponse, 0, len(ch.AccountStatsPricingRules))
+ for _, rule := range ch.AccountStatsPricingRules {
+ ruleResp := accountStatsPricingRuleResponse{
+ ID: rule.ID,
+ Name: rule.Name,
+ GroupIDs: rule.GroupIDs,
+ AccountIDs: rule.AccountIDs,
+ Pricing: make([]channelModelPricingResponse, 0, len(rule.Pricing)),
+ }
+ if ruleResp.GroupIDs == nil {
+ ruleResp.GroupIDs = []int64{}
+ }
+ if ruleResp.AccountIDs == nil {
+ ruleResp.AccountIDs = []int64{}
+ }
+ for i := range rule.Pricing {
+ ruleResp.Pricing = append(ruleResp.Pricing, pricingToResponse(&rule.Pricing[i]))
+ }
+ resp.AccountStatsPricingRules = append(resp.AccountStatsPricingRules, ruleResp)
+ }
+
return resp
}
@@ -241,6 +281,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
return result
}
+func accountStatsPricingRuleRequestToService(r accountStatsPricingRuleRequest) service.AccountStatsPricingRule {
+ return service.AccountStatsPricingRule{
+ Name: r.Name,
+ GroupIDs: r.GroupIDs,
+ AccountIDs: r.AccountIDs,
+ Pricing: pricingRequestToService(r.Pricing),
+ }
+}
+
// --- Handlers ---
// List handles listing channels with pagination
@@ -300,16 +349,24 @@ func (h *ChannelHandler) Create(c *gin.Context) {
pricing := pricingRequestToService(req.ModelPricing)
+ var statsRules []service.AccountStatsPricingRule
+ for i, r := range req.AccountStatsPricingRules {
+ rule := accountStatsPricingRuleRequestToService(r)
+ rule.SortOrder = i
+ statsRules = append(statsRules, rule)
+ }
+
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
- Name: req.Name,
- Description: req.Description,
- GroupIDs: req.GroupIDs,
- ModelPricing: pricing,
- ModelMapping: req.ModelMapping,
- BillingModelSource: req.BillingModelSource,
- RestrictModels: req.RestrictModels,
- Features: req.Features,
- FeaturesConfig: req.FeaturesConfig,
+ Name: req.Name,
+ Description: req.Description,
+ GroupIDs: req.GroupIDs,
+ ModelPricing: pricing,
+ ModelMapping: req.ModelMapping,
+ BillingModelSource: req.BillingModelSource,
+ RestrictModels: req.RestrictModels,
+ Features: req.Features,
+ ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
+ AccountStatsPricingRules: statsRules,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -335,20 +392,29 @@ func (h *ChannelHandler) Update(c *gin.Context) {
}
input := &service.UpdateChannelInput{
- Name: req.Name,
- Description: req.Description,
- Status: req.Status,
- GroupIDs: req.GroupIDs,
- ModelMapping: req.ModelMapping,
- BillingModelSource: req.BillingModelSource,
- RestrictModels: req.RestrictModels,
- Features: req.Features,
- FeaturesConfig: req.FeaturesConfig,
+ Name: req.Name,
+ Description: req.Description,
+ Status: req.Status,
+ GroupIDs: req.GroupIDs,
+ ModelMapping: req.ModelMapping,
+ BillingModelSource: req.BillingModelSource,
+ RestrictModels: req.RestrictModels,
+ Features: req.Features,
+ ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
}
if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing)
input.ModelPricing = &pricing
}
+ if req.AccountStatsPricingRules != nil {
+ statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules))
+ for i, r := range *req.AccountStatsPricingRules {
+ rule := accountStatsPricingRuleRequestToService(r)
+ rule.SortOrder = i
+ statsRules = append(statsRules, rule)
+ }
+ input.AccountStatsPricingRules = &statsRules
+ }
channel, err := h.channelService.Update(c.Request.Context(), id, input)
if err != nil {
diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go
index 56b5cc71..583ce895 100644
--- a/backend/internal/repository/channel_repo.go
+++ b/backend/internal/repository/channel_repo.go
@@ -41,14 +41,10 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
if err != nil {
return err
}
- featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
- if err != nil {
- return err
- }
err = tx.QueryRowContext(ctx,
- `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
+ `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, created_at, updated_at`,
- channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON,
+ channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
@@ -71,17 +67,24 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
}
}
+ // 设置账号统计定价规则
+ if len(channel.AccountStatsPricingRules) > 0 {
+ if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
+ return err
+ }
+ }
+
return nil
})
}
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
ch := &service.Channel{}
- var modelMappingJSON, featuresConfigJSON []byte
+ var modelMappingJSON []byte
err := r.db.QueryRowContext(ctx,
- `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at
+ `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at
FROM channels WHERE id = $1`, id,
- ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt)
+ ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound
}
@@ -89,7 +92,6 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
return nil, fmt.Errorf("get channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
- ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
groupIDs, err := r.GetGroupIDs(ctx, id)
if err != nil {
@@ -103,6 +105,12 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
}
ch.ModelPricing = pricing
+ statsPricingRules, err := r.loadAccountStatsPricingRules(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ ch.AccountStatsPricingRules = statsPricingRules
+
return ch, nil
}
@@ -112,14 +120,10 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
if err != nil {
return err
}
- featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
- if err != nil {
- return err
- }
result, err := tx.ExecContext(ctx,
- `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, updated_at = NOW()
+ `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, apply_pricing_to_account_stats = $8, updated_at = NOW()
WHERE id = $9`,
- channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ID,
+ channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, channel.ID,
)
if err != nil {
if isUniqueViolation(err) {
@@ -146,6 +150,13 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
}
}
+ // 更新账号统计定价规则
+ if channel.AccountStatsPricingRules != nil {
+ if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
+ return err
+ }
+ }
+
return nil
})
}
@@ -196,7 +207,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表
dataQuery := fmt.Sprintf(
- `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.created_at, c.updated_at
+ `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.apply_pricing_to_account_stats, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
)
@@ -212,12 +223,11 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
var channelIDs []int64
for rows.Next() {
var ch service.Channel
- var modelMappingJSON, featuresConfigJSON []byte
- if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
+ var modelMappingJSON []byte
+ if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
- ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -235,9 +245,14 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
if err != nil {
return nil, nil, err
}
+ statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
+ if err != nil {
+ return nil, nil, err
+ }
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
+ channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
}
}
@@ -283,7 +298,7 @@ func channelListOrderBy(params pagination.PaginationParams) string {
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
- `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at FROM channels ORDER BY id`,
+ `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`,
)
if err != nil {
return nil, fmt.Errorf("query all channels: %w", err)
@@ -294,12 +309,11 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
var channelIDs []int64
for rows.Next() {
var ch service.Channel
- var modelMappingJSON, featuresConfigJSON []byte
- if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
+ var modelMappingJSON []byte
+ if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
- ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -323,9 +337,16 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
return nil, err
}
+ // 批量加载账号统计定价规则
+ statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
+ if err != nil {
+ return nil, err
+ }
+
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
+ channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
}
return channels, nil
@@ -467,28 +488,6 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
return m
}
-func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
- if len(m) == 0 {
- return []byte("{}"), nil
- }
- data, err := json.Marshal(m)
- if err != nil {
- return nil, fmt.Errorf("marshal features_config: %w", err)
- }
- return data, nil
-}
-
-func unmarshalFeaturesConfig(data []byte) map[string]any {
- if len(data) == 0 {
- return nil
- }
- var m map[string]any
- if err := json.Unmarshal(data, &m); err != nil {
- return nil
- }
- return m
-}
-
// GetGroupPlatforms 批量查询分组 ID 对应的平台
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
if len(groupIDs) == 0 {
diff --git a/backend/internal/repository/channel_repo_account_stats_pricing.go b/backend/internal/repository/channel_repo_account_stats_pricing.go
new file mode 100644
index 00000000..ef8f5177
--- /dev/null
+++ b/backend/internal/repository/channel_repo_account_stats_pricing.go
@@ -0,0 +1,170 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
+)
+
+// --- 账号统计定价规则 ---
+
+// batchLoadAccountStatsPricingRules 批量加载多个渠道的账号统计定价规则(含模型定价)
+func (r *channelRepository) batchLoadAccountStatsPricingRules(ctx context.Context, channelIDs []int64) (map[int64][]service.AccountStatsPricingRule, error) {
+ // 1. 查询规则
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT id, channel_id, name, group_ids, account_ids, sort_order, created_at, updated_at
+ FROM channel_account_stats_pricing_rules WHERE channel_id = ANY($1) ORDER BY channel_id, sort_order, id`,
+ pq.Array(channelIDs),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("batch load account stats pricing rules: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ var allRules []service.AccountStatsPricingRule
+ var ruleIDs []int64
+ for rows.Next() {
+ var rule service.AccountStatsPricingRule
+ if err := rows.Scan(
+ &rule.ID, &rule.ChannelID, &rule.Name,
+ pq.Array(&rule.GroupIDs), pq.Array(&rule.AccountIDs),
+ &rule.SortOrder, &rule.CreatedAt, &rule.UpdatedAt,
+ ); err != nil {
+ return nil, fmt.Errorf("scan account stats pricing rule: %w", err)
+ }
+ ruleIDs = append(ruleIDs, rule.ID)
+ allRules = append(allRules, rule)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterate account stats pricing rules: %w", err)
+ }
+
+ // 2. 批量加载规则的模型定价
+ pricingMap, err := r.batchLoadAccountStatsModelPricing(ctx, ruleIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ // 3. 按 channelID 分组并关联定价
+ result := make(map[int64][]service.AccountStatsPricingRule, len(channelIDs))
+ for i := range allRules {
+ allRules[i].Pricing = pricingMap[allRules[i].ID]
+ result[allRules[i].ChannelID] = append(result[allRules[i].ChannelID], allRules[i])
+ }
+
+ return result, nil
+}
+
+// batchLoadAccountStatsModelPricing 批量加载规则的模型定价
+func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Context, ruleIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
+ if len(ruleIDs) == 0 {
+ return make(map[int64][]service.ChannelModelPricing), nil
+ }
+
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT id, rule_id, platform, models, billing_mode, input_price, output_price,
+ cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
+ FROM channel_account_stats_model_pricing WHERE rule_id = ANY($1) ORDER BY rule_id, id`,
+ pq.Array(ruleIDs),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("batch load account stats model pricing: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ pricingMap := make(map[int64][]service.ChannelModelPricing, len(ruleIDs))
+ for rows.Next() {
+ var p service.ChannelModelPricing
+ var ruleID int64
+ var modelsJSON []byte
+ if err := rows.Scan(
+ &p.ID, &ruleID, &p.Platform, &modelsJSON, &p.BillingMode,
+ &p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
+ &p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
+ ); err != nil {
+ return nil, fmt.Errorf("scan account stats model pricing: %w", err)
+ }
+ if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
+ p.Models = []string{}
+ }
+ pricingMap[ruleID] = append(pricingMap[ruleID], p)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterate account stats model pricing: %w", err)
+ }
+ return pricingMap, nil
+}
+
+// loadAccountStatsPricingRules 加载单个渠道的账号统计定价规则(供 GetByID 使用)
+func (r *channelRepository) loadAccountStatsPricingRules(ctx context.Context, channelID int64) ([]service.AccountStatsPricingRule, error) {
+ result, err := r.batchLoadAccountStatsPricingRules(ctx, []int64{channelID})
+ if err != nil {
+ return nil, err
+ }
+ return result[channelID], nil
+}
+
+// replaceAccountStatsPricingRulesTx 在事务中替换渠道的账号统计定价规则(删除旧的 + 插入新的)
+func replaceAccountStatsPricingRulesTx(ctx context.Context, tx *sql.Tx, channelID int64, rules []service.AccountStatsPricingRule) error {
+ // CASCADE 会自动删除关联的 model_pricing
+ if _, err := tx.ExecContext(ctx,
+ `DELETE FROM channel_account_stats_pricing_rules WHERE channel_id = $1`, channelID,
+ ); err != nil {
+ return fmt.Errorf("delete old account stats pricing rules: %w", err)
+ }
+
+ for i := range rules {
+ rules[i].ChannelID = channelID
+ if err := createAccountStatsPricingRuleTx(ctx, tx, &rules[i]); err != nil {
+ return fmt.Errorf("insert account stats pricing rule: %w", err)
+ }
+ }
+ return nil
+}
+
+// createAccountStatsPricingRuleTx 在事务中创建单条账号统计定价规则及其模型定价
+func createAccountStatsPricingRuleTx(ctx context.Context, tx *sql.Tx, rule *service.AccountStatsPricingRule) error {
+ err := tx.QueryRowContext(ctx,
+ `INSERT INTO channel_account_stats_pricing_rules (channel_id, name, group_ids, account_ids, sort_order)
+ VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
+ rule.ChannelID, rule.Name, pq.Array(rule.GroupIDs), pq.Array(rule.AccountIDs), rule.SortOrder,
+ ).Scan(&rule.ID, &rule.CreatedAt, &rule.UpdatedAt)
+ if err != nil {
+ return fmt.Errorf("insert account stats pricing rule: %w", err)
+ }
+
+ for j := range rule.Pricing {
+ if err := createAccountStatsModelPricingTx(ctx, tx, rule.ID, &rule.Pricing[j]); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// createAccountStatsModelPricingTx 在事务中创建单条账号统计模型定价
+func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID int64, pricing *service.ChannelModelPricing) error {
+ modelsJSON, err := json.Marshal(pricing.Models)
+ if err != nil {
+ return fmt.Errorf("marshal models: %w", err)
+ }
+ billingMode := pricing.BillingMode
+ if billingMode == "" {
+ billingMode = service.BillingModeToken
+ }
+ platform := pricing.Platform
+ err = tx.QueryRowContext(ctx,
+ `INSERT INTO channel_account_stats_model_pricing (rule_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
+ ruleID, platform, modelsJSON, billingMode,
+ pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
+ pricing.ImageOutputPrice, pricing.PerRequestPrice,
+ ).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
+ if err != nil {
+ return fmt.Errorf("insert account stats model pricing: %w", err)
+ }
+ return nil
+}
diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go
index 3ba2191e..f942a8e1 100644
--- a/backend/internal/repository/usage_log_repo.go
+++ b/backend/internal/repository/usage_log_repo.go
@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
-const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
+const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
@@ -82,6 +82,7 @@ var usageLogInsertArgTypes = [...]string{
"text", // model_mapping_chain
"text", // billing_tier
"text", // billing_mode
+ "numeric", // account_stats_cost
"timestamptz", // created_at
}
@@ -360,6 +361,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
@@ -367,7 +369,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
- $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
+ $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -797,6 +799,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
) AS (VALUES `)
@@ -873,6 +876,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
)
SELECT
@@ -920,6 +924,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
@@ -1007,10 +1012,11 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
) AS (VALUES `)
- args := make([]any, 0, len(preparedList)*45)
+ args := make([]any, 0, len(preparedList)*46)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@@ -1080,6 +1086,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
)
SELECT
@@ -1127,6 +1134,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
@@ -1182,6 +1190,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
model_mapping_chain,
billing_tier,
billing_mode,
+ account_stats_cost,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
@@ -1189,7 +1198,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
- $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
+ $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@@ -1285,6 +1294,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
modelMappingChain,
billingTier,
billingMode,
+ log.AccountStatsCost, // account_stats_cost
createdAt,
},
}
@@ -1959,7 +1969,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -1989,7 +1999,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -2026,7 +2036,7 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
account_id,
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -2990,7 +3000,7 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if accountID > 0 && userID == 0 && apiKeyID == 0 {
- actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
+ actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
modelExpr := resolveModelDimensionExpression(source)
@@ -3358,7 +3368,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
@@ -3433,7 +3443,7 @@ type EndpointStat = usagestats.EndpointStat
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
- actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
+ actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
@@ -3500,7 +3510,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
- actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
+ actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
@@ -3591,7 +3601,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
@@ -4069,6 +4079,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
modelMappingChain sql.NullString
billingTier sql.NullString
billingMode sql.NullString
+ accountStatsCost sql.NullFloat64
createdAt time.Time
)
@@ -4118,6 +4129,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&modelMappingChain,
&billingTier,
&billingMode,
+ &accountStatsCost,
&createdAt,
); err != nil {
return nil, err
@@ -4214,6 +4226,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if billingMode.Valid {
log.BillingMode = &billingMode.String
}
+ if accountStatsCost.Valid {
+ log.AccountStatsCost = &accountStatsCost.Float64
+ }
return log, nil
}
diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go
index b9cb6a13..acdd6e62 100644
--- a/backend/internal/repository/usage_log_repo_request_type_test.go
+++ b/backend/internal/repository/usage_log_repo_request_type_test.go
@@ -85,6 +85,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
+ sqlmock.AnyArg(), // account_stats_cost
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
@@ -163,6 +164,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
+ sqlmock.AnyArg(), // account_stats_cost
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
@@ -483,10 +485,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
- sql.NullInt64{}, // channel_id
- sql.NullString{}, // model_mapping_chain
- sql.NullString{}, // billing_tier
- sql.NullString{}, // billing_mode
+ sql.NullInt64{}, // channel_id
+ sql.NullString{}, // model_mapping_chain
+ sql.NullString{}, // billing_tier
+ sql.NullString{}, // billing_mode
+ sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)
@@ -530,10 +533,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
- sql.NullInt64{}, // channel_id
- sql.NullString{}, // model_mapping_chain
- sql.NullString{}, // billing_tier
- sql.NullString{}, // billing_mode
+ sql.NullInt64{}, // channel_id
+ sql.NullString{}, // model_mapping_chain
+ sql.NullString{}, // billing_tier
+ sql.NullString{}, // billing_mode
+ sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)
@@ -577,10 +581,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
- sql.NullInt64{}, // channel_id
- sql.NullString{}, // model_mapping_chain
- sql.NullString{}, // billing_tier
- sql.NullString{}, // billing_mode
+ sql.NullInt64{}, // channel_id
+ sql.NullString{}, // model_mapping_chain
+ sql.NullString{}, // billing_tier
+ sql.NullString{}, // billing_mode
+ sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)
diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go
new file mode 100644
index 00000000..86f98a12
--- /dev/null
+++ b/backend/internal/service/account_stats_pricing.go
@@ -0,0 +1,192 @@
+package service
+
+import (
+ "context"
+ "sort"
+ "strings"
+)
+
+// resolveAccountStatsCost computes the account-stats pricing cost.
+// A nil return means "do not override": the caller falls back to the default
+// formula (total_cost × account_rate_multiplier).
+//
+// Match priority (first hit wins):
+// 1. custom rules (AccountStatsPricingRules, scanned in array order)
+// 2. the channel's existing model pricing (when ApplyPricingToAccountStats is on)
+// 3. nil → default formula
+//
+// NOTE(review): serviceTier is currently unused in this function's body —
+// confirm whether it is reserved for future tier-specific pricing.
+func resolveAccountStatsCost(
+ ctx context.Context,
+ channelService *ChannelService,
+ billingService *BillingService,
+ accountID int64,
+ groupID int64,
+ billingModel string,
+ tokens UsageTokens,
+ requestCount int,
+ serviceTier string,
+) *float64 {
+ if channelService == nil || billingService == nil {
+ return nil
+ }
+ // No channel, lookup failure, or feature disabled → no override.
+ channel, err := channelService.GetChannelForGroup(ctx, groupID)
+ if err != nil || channel == nil || !channel.ApplyPricingToAccountStats {
+ return nil
+ }
+
+ platform := channelService.GetGroupPlatform(ctx, groupID)
+ modelLower := strings.ToLower(billingModel)
+
+ // Priority 1: custom rules.
+ if cost := tryCustomRules(channel, accountID, groupID, platform, modelLower, tokens, requestCount); cost != nil {
+ return cost
+ }
+
+ // Priority 2: the channel's existing model pricing.
+ return tryChannelPricing(ctx, channelService, groupID, billingModel, tokens, requestCount)
+}
+
+// tryCustomRules walks the custom rules in array order; the first rule that
+// matches both the account/group condition and the model wins. Returns nil
+// when no rule produces a cost.
+func tryCustomRules(
+ channel *Channel, accountID, groupID int64,
+ platform, modelLower string, tokens UsageTokens, requestCount int,
+) *float64 {
+ for _, rule := range channel.AccountStatsPricingRules {
+ if !matchAccountStatsRule(&rule, accountID, groupID) {
+ continue
+ }
+ pricing := findPricingForModel(rule.Pricing, platform, modelLower)
+ if pricing == nil {
+ continue // rule matched but the model is not priced by it; try the next rule
+ }
+ return calculateStatsCost(pricing, tokens, requestCount)
+ }
+ return nil
+}
+
+// tryChannelPricing computes the account-stats cost from the channel's
+// existing model pricing (the fallback after custom rules). Returns nil when
+// the channel has no pricing entry for the model.
+func tryChannelPricing(
+ ctx context.Context, channelService *ChannelService,
+ groupID int64, billingModel string, tokens UsageTokens, requestCount int,
+) *float64 {
+ pricing := channelService.GetChannelModelPricing(ctx, groupID, billingModel)
+ if pricing == nil {
+ return nil
+ }
+ return calculateStatsCost(pricing, tokens, requestCount)
+}
+
+// matchAccountStatsRule reports whether the rule matches the given accountID
+// and groupID. Match condition: accountID ∈ rule.AccountIDs OR
+// groupID ∈ rule.GroupIDs. A rule with both AccountIDs and GroupIDs empty is
+// treated as matching nothing (never "match everything").
+func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
+ if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 {
+ return false
+ }
+ for _, id := range rule.AccountIDs {
+ if id == accountID {
+ return true
+ }
+ }
+ for _, id := range rule.GroupIDs {
+ if id == groupID {
+ return true
+ }
+ }
+ return false
+}
+
+// wildcardMatch is a wildcard-match candidate (used for sorting by prefix length).
+type wildcardMatch struct {
+ prefixLen int
+ pricing *ChannelModelPricing
+}
+
+// findPricingForModel finds the pricing entry matching the model in the list.
+// Exact (case-insensitive) matches are tried first; then wildcard patterns of
+// the form "prefix*", where the longest matching prefix wins. modelLower must
+// already be lowercased by the caller. Returns nil when nothing matches.
+func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
+ // Exact match takes priority.
+ for i := range pricingList {
+ p := &pricingList[i]
+ if !isPlatformMatch(platform, p.Platform) {
+ continue
+ }
+ for _, m := range p.Models {
+ if strings.ToLower(m) == modelLower {
+ return p
+ }
+ }
+ }
+ // Wildcard match: collect every candidate, then pick the longest prefix.
+ var matches []wildcardMatch
+ for i := range pricingList {
+ p := &pricingList[i]
+ if !isPlatformMatch(platform, p.Platform) {
+ continue
+ }
+ for _, m := range p.Models {
+ ml := strings.ToLower(m)
+ if !strings.HasSuffix(ml, "*") {
+ continue
+ }
+ prefix := strings.TrimSuffix(ml, "*")
+ if strings.HasPrefix(modelLower, prefix) {
+ matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p})
+ }
+ }
+ }
+ if len(matches) == 0 {
+ return nil
+ }
+ sort.Slice(matches, func(i, j int) bool {
+ return matches[i].prefixLen > matches[j].prefixLen
+ })
+ return matches[0].pricing
+}
+
+// isPlatformMatch reports whether the platforms are compatible. An empty
+// platform on either side acts as a wildcard ("any platform").
+func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
+ if queryPlatform == "" || pricingPlatform == "" {
+ return true
+ }
+ return queryPlatform == pricingPlatform
+}
+
+// calculateStatsCost computes the raw cost from the given pricing entry — no
+// multipliers are applied. Dispatches on BillingMode: per-request/image modes
+// bill by request count, everything else (including the empty default) bills
+// by tokens. Returns nil for nil pricing or when the mode-specific helper
+// declines to produce a cost.
+func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
+ if pricing == nil {
+ return nil
+ }
+ switch pricing.BillingMode {
+ case BillingModePerRequest, BillingModeImage:
+ return calculatePerRequestStatsCost(pricing, requestCount)
+ default:
+ return calculateTokenStatsCost(pricing, tokens)
+ }
+}
+
+// calculatePerRequestStatsCost bills per request (or per image). A nil or
+// non-positive PerRequestPrice yields nil, i.e. no override.
+func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
+ if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 {
+ return nil
+ }
+ cost := *pricing.PerRequestPrice * float64(requestCount)
+ return &cost
+}
+
+// calculateTokenStatsCost bills by tokens: each token class is multiplied by
+// its per-token price (nil prices count as 0). A total of exactly 0 returns
+// nil so the caller falls back to the default formula instead of overriding
+// with a zero cost.
+func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
+ deref := func(p *float64) float64 {
+ if p == nil {
+ return 0
+ }
+ return *p
+ }
+ cost := float64(tokens.InputTokens)*deref(pricing.InputPrice) +
+ float64(tokens.OutputTokens)*deref(pricing.OutputPrice) +
+ float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) +
+ float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) +
+ float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice)
+ if cost == 0 {
+ return nil
+ }
+ return &cost
+}
diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go
new file mode 100644
index 00000000..bc3db251
--- /dev/null
+++ b/backend/internal/service/account_stats_pricing_test.go
@@ -0,0 +1,430 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// matchAccountStatsRule — matching by account ID / group ID (OR semantics)
+// ---------------------------------------------------------------------------
+
+// A rule with neither AccountIDs nor GroupIDs configured must never match.
+func TestMatchAccountStatsRule_BothEmpty_NoMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{}
+ require.False(t, matchAccountStatsRule(rule, 1, 10))
+}
+
+// Account-ID membership alone is sufficient; group ID is irrelevant here.
+func TestMatchAccountStatsRule_AccountIDMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{AccountIDs: []int64{1, 2, 3}}
+ require.True(t, matchAccountStatsRule(rule, 2, 999))
+}
+
+// Group-ID membership alone is sufficient; account ID is irrelevant here.
+func TestMatchAccountStatsRule_GroupIDMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{GroupIDs: []int64{10, 20}}
+ require.True(t, matchAccountStatsRule(rule, 999, 20))
+}
+
+// When both lists are configured, an account hit alone matches (OR, not AND).
+func TestMatchAccountStatsRule_BothConfigured_AccountMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{
+ AccountIDs: []int64{1, 2},
+ GroupIDs: []int64{10, 20},
+ }
+ require.True(t, matchAccountStatsRule(rule, 2, 999))
+}
+
+// When both lists are configured, a group hit alone matches (OR, not AND).
+func TestMatchAccountStatsRule_BothConfigured_GroupMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{
+ AccountIDs: []int64{1, 2},
+ GroupIDs: []int64{10, 20},
+ }
+ require.True(t, matchAccountStatsRule(rule, 999, 10))
+}
+
+// Neither list hits → no match.
+func TestMatchAccountStatsRule_BothConfigured_NeitherMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{
+ AccountIDs: []int64{1, 2},
+ GroupIDs: []int64{10, 20},
+ }
+ require.False(t, matchAccountStatsRule(rule, 999, 999))
+}
+
+// ---------------------------------------------------------------------------
+// findPricingForModel — exact vs. wildcard matching and platform filtering
+// ---------------------------------------------------------------------------
+
+// Table-driven coverage: exact match (case-insensitive), wildcard match,
+// exact-over-wildcard priority, platform filtering (empty = any), and
+// longest-wildcard-prefix-wins ordering.
+func TestFindPricingForModel(t *testing.T) {
+ exactPricing := ChannelModelPricing{
+ ID: 1,
+ Models: []string{"claude-opus-4"},
+ }
+ wildcardPricing := ChannelModelPricing{
+ ID: 2,
+ Models: []string{"claude-*"},
+ }
+ platformPricing := ChannelModelPricing{
+ ID: 3,
+ Platform: "openai",
+ Models: []string{"gpt-4o"},
+ }
+ emptyPlatformPricing := ChannelModelPricing{
+ ID: 4,
+ Models: []string{"gemini-2.5-pro"},
+ }
+
+ tests := []struct {
+ name string
+ list []ChannelModelPricing
+ platform string
+ model string
+ wantID int64
+ wantNil bool
+ }{
+ {
+ name: "exact match",
+ list: []ChannelModelPricing{exactPricing},
+ platform: "anthropic",
+ model: "claude-opus-4",
+ wantID: 1,
+ },
+ {
+ name: "exact match case insensitive",
+ list: []ChannelModelPricing{{ID: 5, Models: []string{"Claude-Opus-4"}}},
+ platform: "",
+ model: "claude-opus-4",
+ wantID: 5,
+ },
+ {
+ name: "wildcard match",
+ list: []ChannelModelPricing{wildcardPricing},
+ platform: "anthropic",
+ model: "claude-opus-4",
+ wantID: 2,
+ },
+ {
+ name: "exact match takes priority over wildcard",
+ list: []ChannelModelPricing{wildcardPricing, exactPricing},
+ platform: "anthropic",
+ model: "claude-opus-4",
+ wantID: 1,
+ },
+ {
+ name: "platform mismatch skipped",
+ list: []ChannelModelPricing{platformPricing},
+ platform: "anthropic",
+ model: "gpt-4o",
+ wantNil: true,
+ },
+ {
+ name: "empty platform in pricing matches any",
+ list: []ChannelModelPricing{emptyPlatformPricing},
+ platform: "gemini",
+ model: "gemini-2.5-pro",
+ wantID: 4,
+ },
+ {
+ name: "empty platform in query matches any pricing platform",
+ list: []ChannelModelPricing{platformPricing},
+ platform: "",
+ model: "gpt-4o",
+ wantID: 3,
+ },
+ {
+ name: "no match at all",
+ list: []ChannelModelPricing{exactPricing, wildcardPricing},
+ platform: "anthropic",
+ model: "gpt-4o",
+ wantNil: true,
+ },
+ {
+ name: "empty list returns nil",
+ list: nil,
+ model: "claude-opus-4",
+ wantNil: true,
+ },
+ {
+ name: "longer wildcard prefix wins over shorter",
+ list: []ChannelModelPricing{
+ {ID: 10, Models: []string{"claude-*"}},
+ {ID: 11, Models: []string{"claude-opus-*"}},
+ },
+ platform: "",
+ model: "claude-opus-4",
+ wantID: 11, // "claude-opus-" (12 chars) > "claude-" (7 chars)
+ },
+ {
+ name: "shorter wildcard used when longer does not match",
+ list: []ChannelModelPricing{
+ {ID: 10, Models: []string{"claude-*"}},
+ {ID: 11, Models: []string{"claude-opus-*"}},
+ },
+ platform: "",
+ model: "claude-sonnet-4",
+ wantID: 10, // only "claude-*" matches
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := findPricingForModel(tt.list, tt.platform, tt.model)
+ if tt.wantNil {
+ require.Nil(t, result)
+ return
+ }
+ require.NotNil(t, result)
+ require.Equal(t, tt.wantID, result.ID)
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// calculateStatsCost — token / per-request / image billing paths
+// (testPtrFloat64 is a *float64 helper defined elsewhere in this package)
+// ---------------------------------------------------------------------------
+
+// nil pricing → nil result (no override).
+func TestCalculateStatsCost_NilPricing(t *testing.T) {
+ result := calculateStatsCost(nil, UsageTokens{}, 1)
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_TokenBilling(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
+ require.InDelta(t, 0.2, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_TokenBilling_WithCache(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ CacheWritePrice: testPtrFloat64(0.003),
+ CacheReadPrice: testPtrFloat64(0.0005),
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ CacheCreationTokens: 200,
+ CacheReadTokens: 300,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
+ // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
+ require.InDelta(t, 0.95, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_TokenBilling_WithImageOutput(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ ImageOutputPrice: testPtrFloat64(0.01),
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ ImageOutputTokens: 10,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
+ require.InDelta(t, 0.3, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_TokenBilling_PartialPricesNil(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ // OutputPrice, CacheWritePrice, etc. are all nil → treated as 0
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ CacheCreationTokens: 200,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ // Only input contributes: 100*0.001 = 0.1
+ require.InDelta(t, 0.1, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_TokenBilling_AllTokensZero(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ }
+ tokens := UsageTokens{} // all zeros
+ result := calculateStatsCost(pricing, tokens, 1)
+ // totalCost == 0 → returns nil (does not override, falls back to default formula)
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_PerRequestBilling(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModePerRequest,
+ PerRequestPrice: testPtrFloat64(0.05),
+ }
+ tokens := UsageTokens{InputTokens: 999, OutputTokens: 999}
+ result := calculateStatsCost(pricing, tokens, 3)
+ require.NotNil(t, result)
+ // 0.05 * 3 = 0.15 — token counts must be ignored in per-request mode
+ require.InDelta(t, 0.15, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_PerRequestBilling_PriceNil(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModePerRequest,
+ // PerRequestPrice is nil
+ }
+ result := calculateStatsCost(pricing, UsageTokens{}, 1)
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_PerRequestBilling_PriceZero(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModePerRequest,
+ PerRequestPrice: testPtrFloat64(0),
+ }
+ result := calculateStatsCost(pricing, UsageTokens{}, 1)
+ // price == 0 → condition *pricing.PerRequestPrice > 0 is false → returns nil
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_ImageBilling(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeImage,
+ PerRequestPrice: testPtrFloat64(0.10),
+ }
+ result := calculateStatsCost(pricing, UsageTokens{}, 2)
+ require.NotNil(t, result)
+ // 0.10 * 2 = 0.20 — image mode shares the per-request pricing path
+ require.InDelta(t, 0.20, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_ImageBilling_PriceNil(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeImage,
+ // PerRequestPrice is nil
+ }
+ result := calculateStatsCost(pricing, UsageTokens{}, 1)
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_DefaultBillingMode_FallsToToken(t *testing.T) {
+ // BillingMode is empty string (default) → falls into token billing
+ pricing := &ChannelModelPricing{
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ require.InDelta(t, 0.2, *result, 1e-12)
+}
+
+// ---------------------------------------------------------------------------
+// tryCustomRules — multi-rule ordering tests (first matching rule wins)
+// ---------------------------------------------------------------------------
+
+func TestTryCustomRules_FirstMatchWins(t *testing.T) {
+ channel := &Channel{
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ GroupIDs: []int64{1},
+ Pricing: []ChannelModelPricing{
+ {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01), OutputPrice: testPtrFloat64(0.02)},
+ },
+ },
+ {
+ GroupIDs: []int64{1},
+ Pricing: []ChannelModelPricing{
+ {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99), OutputPrice: testPtrFloat64(0.99)},
+ },
+ },
+ },
+ }
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+ result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
+ require.NotNil(t, result)
+ // The first rule's prices must apply: 100*0.01 + 50*0.02 = 2.0
+ require.InDelta(t, 2.0, *result, 1e-12)
+}
+
+func TestTryCustomRules_SkipsNonMatchingRules(t *testing.T) {
+ channel := &Channel{
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ AccountIDs: []int64{888}, // does not match
+ Pricing: []ChannelModelPricing{
+ {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99)},
+ },
+ },
+ {
+ GroupIDs: []int64{1}, // matches
+ Pricing: []ChannelModelPricing{
+ {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)},
+ },
+ },
+ },
+ }
+ tokens := UsageTokens{InputTokens: 100}
+ result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
+ require.NotNil(t, result)
+ // Rule 1 is skipped (account mismatch); rule 2 applies: 100*0.05 = 5.0
+ require.InDelta(t, 5.0, *result, 1e-12)
+}
+
+func TestTryCustomRules_NoMatch_ReturnsNil(t *testing.T) {
+ channel := &Channel{
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ AccountIDs: []int64{888},
+ Pricing: []ChannelModelPricing{
+ {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01)},
+ },
+ },
+ },
+ }
+ tokens := UsageTokens{InputTokens: 100}
+ result := tryCustomRules(channel, 999, 2, "", "claude-opus-4", tokens, 1)
+ require.Nil(t, result) // neither account nor group matches
+}
+
+func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) {
+ channel := &Channel{
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ GroupIDs: []int64{1},
+ Pricing: []ChannelModelPricing{
+ {ID: 100, Models: []string{"gpt-4o"}, InputPrice: testPtrFloat64(0.01)}, // model does not match
+ },
+ },
+ {
+ GroupIDs: []int64{1},
+ Pricing: []ChannelModelPricing{
+ {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, // model matches
+ },
+ },
+ },
+ }
+ tokens := UsageTokens{InputTokens: 100}
+ result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
+ require.NotNil(t, result)
+ require.InDelta(t, 5.0, *result, 1e-12) // rule 2 applies
+}
diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go
index baf5c839..3867f2a0 100644
--- a/backend/internal/service/channel.go
+++ b/backend/internal/service/channel.go
@@ -49,21 +49,25 @@ type Channel struct {
ModelPricing []ChannelModelPricing
// 渠道级模型映射(按平台分组:platform → {src→dst})
ModelMapping map[string]map[string]string
- // 渠道特性配置(如 {"web_search_emulation": {"anthropic": true}})
- FeaturesConfig map[string]any
+
+ // 账号统计定价
+ ApplyPricingToAccountStats bool // 是否应用渠道模型定价到账号统计
+ AccountStatsPricingRules []AccountStatsPricingRule // 自定义账号统计定价规则(按 SortOrder 排序,先命中为准)
}
-// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
-func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
- if c == nil || c.FeaturesConfig == nil {
- return false
- }
- wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
- if !ok {
- return false
- }
- enabled, ok := wse[platform].(bool)
- return ok && enabled
+// AccountStatsPricingRule is a single account-stats pricing rule.
+// Each rule carries its own match conditions (groups/accounts) and an
+// independent set of model pricing entries. Multiple rules are ordered by
+// SortOrder and the first matching rule wins.
+type AccountStatsPricingRule struct {
+ ID int64
+ ChannelID int64
+ Name string
+ GroupIDs []int64
+ AccountIDs []int64
+ SortOrder int
+ Pricing []ChannelModelPricing // per-rule model pricing (reuses the existing pricing struct)
+ CreatedAt time.Time
+ UpdatedAt time.Time
}
// ChannelModelPricing 渠道模型定价条目
@@ -192,6 +196,26 @@ func (c *Channel) Clone() *Channel {
cp.ModelMapping[platform] = inner
}
}
+ if c.AccountStatsPricingRules != nil {
+ cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules))
+ for i, rule := range c.AccountStatsPricingRules {
+ cp.AccountStatsPricingRules[i] = rule
+ if rule.GroupIDs != nil {
+ cp.AccountStatsPricingRules[i].GroupIDs = make([]int64, len(rule.GroupIDs))
+ copy(cp.AccountStatsPricingRules[i].GroupIDs, rule.GroupIDs)
+ }
+ if rule.AccountIDs != nil {
+ cp.AccountStatsPricingRules[i].AccountIDs = make([]int64, len(rule.AccountIDs))
+ copy(cp.AccountStatsPricingRules[i].AccountIDs, rule.AccountIDs)
+ }
+ if rule.Pricing != nil {
+ cp.AccountStatsPricingRules[i].Pricing = make([]ChannelModelPricing, len(rule.Pricing))
+ for j := range rule.Pricing {
+ cp.AccountStatsPricingRules[i].Pricing[j] = rule.Pricing[j].Clone()
+ }
+ }
+ }
+ }
return &cp
}
diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go
index 7b28662b..d0698f0f 100644
--- a/backend/internal/service/channel_service.go
+++ b/backend/internal/service/channel_service.go
@@ -416,6 +416,15 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
return ch.Clone(), nil
}
+// GetGroupPlatform 获取分组的平台标识(从缓存)
+func (s *ChannelService) GetGroupPlatform(ctx context.Context, groupID int64) string {
+ cache, err := s.loadCache(ctx)
+ if err != nil {
+ return ""
+ }
+ return cache.groupPlatform[groupID]
+}
+
// channelLookup 热路径公共查找结果
type channelLookup struct {
cache *channelCache
@@ -656,16 +665,17 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
}
channel := &Channel{
- Name: input.Name,
- Description: input.Description,
- Status: StatusActive,
- BillingModelSource: input.BillingModelSource,
- RestrictModels: input.RestrictModels,
- GroupIDs: input.GroupIDs,
- ModelPricing: input.ModelPricing,
- ModelMapping: input.ModelMapping,
- Features: input.Features,
- FeaturesConfig: input.FeaturesConfig,
+ Name: input.Name,
+ Description: input.Description,
+ Status: StatusActive,
+ BillingModelSource: input.BillingModelSource,
+ RestrictModels: input.RestrictModels,
+ GroupIDs: input.GroupIDs,
+ ModelPricing: input.ModelPricing,
+ ModelMapping: input.ModelMapping,
+ Features: input.Features,
+ ApplyPricingToAccountStats: input.ApplyPricingToAccountStats,
+ AccountStatsPricingRules: input.AccountStatsPricingRules,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceChannelMapped
@@ -754,8 +764,11 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
if input.BillingModelSource != "" {
channel.BillingModelSource = input.BillingModelSource
}
- if input.FeaturesConfig != nil {
- channel.FeaturesConfig = input.FeaturesConfig
+ if input.ApplyPricingToAccountStats != nil {
+ channel.ApplyPricingToAccountStats = *input.ApplyPricingToAccountStats
+ }
+ if input.AccountStatsPricingRules != nil {
+ channel.AccountStatsPricingRules = *input.AccountStatsPricingRules
}
return nil
}
@@ -922,27 +935,29 @@ func detectConflicts(entries []modelEntry, platform, errCode, label string) erro
// CreateChannelInput 创建渠道输入
type CreateChannelInput struct {
- Name string
- Description string
- GroupIDs []int64
- ModelPricing []ChannelModelPricing
- ModelMapping map[string]map[string]string // platform → {src→dst}
- BillingModelSource string
- RestrictModels bool
- Features string
- FeaturesConfig map[string]any
+ Name string
+ Description string
+ GroupIDs []int64
+ ModelPricing []ChannelModelPricing
+ ModelMapping map[string]map[string]string // platform → {src→dst}
+ BillingModelSource string
+ RestrictModels bool
+ Features string
+ ApplyPricingToAccountStats bool
+ AccountStatsPricingRules []AccountStatsPricingRule
}
// UpdateChannelInput 更新渠道输入
type UpdateChannelInput struct {
- Name string
- Description *string
- Status string
- GroupIDs *[]int64
- ModelPricing *[]ChannelModelPricing
- ModelMapping map[string]map[string]string // platform → {src→dst}
- BillingModelSource string
- RestrictModels *bool
- Features *string
- FeaturesConfig map[string]any
+ Name string
+ Description *string
+ Status string
+ GroupIDs *[]int64
+ ModelPricing *[]ChannelModelPricing
+ ModelMapping map[string]map[string]string // platform → {src→dst}
+ BillingModelSource string
+ RestrictModels *bool
+ Features *string
+ ApplyPricingToAccountStats *bool
+ AccountStatsPricingRules *[]AccountStatsPricingRule
}
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 77e9b8c8..1d6d0a08 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -7559,6 +7559,23 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
+ // 计算账号统计定价费用
+ if apiKey.GroupID != nil {
+ usageLog.AccountStatsCost = resolveAccountStatsCost(
+ ctx, s.channelService, s.billingService,
+ account.ID, *apiKey.GroupID, billingModel,
+ UsageTokens{
+ InputTokens: result.Usage.InputTokens,
+ OutputTokens: result.Usage.OutputTokens,
+ CacheCreationTokens: result.Usage.CacheCreationInputTokens,
+ CacheReadTokens: result.Usage.CacheReadInputTokens,
+ ImageOutputTokens: result.Usage.ImageOutputTokens,
+ },
+ 1, // requestCount
+ "", // serviceTier: Anthropic 平台不使用 service tier
+ )
+ }
+
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index dbc53869..3daa8756 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -4569,6 +4569,15 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID
}
+ // 计算账号统计定价费用
+ if apiKey.GroupID != nil {
+ usageLog.AccountStatsCost = resolveAccountStatsCost(
+ ctx, s.channelService, s.billingService,
+ account.ID, *apiKey.GroupID, billingModel,
+ tokens, 1, serviceTier,
+ )
+ }
+
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go
index 3218f3db..e29d282e 100644
--- a/backend/internal/service/usage_log.go
+++ b/backend/internal/service/usage_log.go
@@ -146,6 +146,8 @@ type UsageLog struct {
RateMultiplier float64
// AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
AccountRateMultiplier *float64
+ // AccountStatsCost 账号统计定价预计算费用(nil = 使用默认公式 total_cost × account_rate_multiplier)
+ AccountStatsCost *float64
BillingType int8
RequestType RequestType
diff --git a/backend/migrations/101_add_account_stats_pricing.sql b/backend/migrations/101_add_account_stats_pricing.sql
new file mode 100644
index 00000000..a61d0c26
--- /dev/null
+++ b/backend/migrations/101_add_account_stats_pricing.sql
@@ -0,0 +1,38 @@
-- Account statistics pricing: allow channels to configure custom pricing for account cost tracking.
-- Semantics:
--   * apply_pricing_to_account_stats = FALSE keeps the legacy cost formula.
--   * usage_logs.account_stats_cost NULL means "use the default formula".

-- 1. Channel-level toggle
ALTER TABLE channels ADD COLUMN IF NOT EXISTS apply_pricing_to_account_stats BOOLEAN NOT NULL DEFAULT FALSE;

-- 2. Account stats pricing rules (ordered list per channel)
-- group_ids / account_ids are the match conditions; rules are evaluated by sort_order.
CREATE TABLE IF NOT EXISTS channel_account_stats_pricing_rules (
    id BIGSERIAL PRIMARY KEY,
    channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
    name VARCHAR(100) NOT NULL DEFAULT '',
    group_ids BIGINT[] NOT NULL DEFAULT '{}',
    account_ids BIGINT[] NOT NULL DEFAULT '{}',
    sort_order INT NOT NULL DEFAULT 0,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_cas_pricing_rules_channel_id ON channel_account_stats_pricing_rules(channel_id);

-- 3. Model pricing for each rule (same structure as channel_model_pricing)
-- Prices are per-token; NULL price columns fall back to defaults at the application layer.
CREATE TABLE IF NOT EXISTS channel_account_stats_model_pricing (
    id BIGSERIAL PRIMARY KEY,
    rule_id BIGINT NOT NULL REFERENCES channel_account_stats_pricing_rules(id) ON DELETE CASCADE,
    platform VARCHAR(50) NOT NULL DEFAULT '',
    models JSONB NOT NULL DEFAULT '[]',
    billing_mode VARCHAR(20) NOT NULL DEFAULT 'token',
    input_price NUMERIC(20,10),
    output_price NUMERIC(20,10),
    cache_write_price NUMERIC(20,10),
    cache_read_price NUMERIC(20,10),
    image_output_price NUMERIC(20,10),
    per_request_price NUMERIC(20,10),
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_cas_model_pricing_rule_id ON channel_account_stats_model_pricing(rule_id);

-- 4. Usage logs: pre-computed account stats cost (NULL = use default formula)
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS account_stats_cost NUMERIC(20,10);
diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts
index d49982aa..a13eb3e1 100644
--- a/frontend/src/api/admin/channels.ts
+++ b/frontend/src/api/admin/channels.ts
@@ -34,6 +34,14 @@ export interface ChannelModelPricing {
intervals: PricingInterval[]
}
/** A custom account-statistics pricing rule attached to a channel. */
export interface AccountStatsPricingRule {
  /** Database ID; absent for rules not yet persisted. */
  id?: number
  /** Optional display name for the rule. */
  name: string
  /** Group IDs this rule matches. */
  group_ids: number[]
  /** Account IDs this rule matches. */
  account_ids: number[]
  /** Model pricing entries applied when this rule matches. */
  pricing: ChannelModelPricing[]
}
+
export interface Channel {
id: number
name: string
@@ -41,10 +49,11 @@ export interface Channel {
status: string
billing_model_source: string // "requested" | "upstream"
restrict_models: boolean
- features_config?: Record
group_ids: number[]
model_pricing: ChannelModelPricing[]
model_mapping: Record> // platform → {src→dst}
+ apply_pricing_to_account_stats: boolean
+ account_stats_pricing_rules: AccountStatsPricingRule[]
created_at: string
updated_at: string
}
@@ -57,7 +66,8 @@ export interface CreateChannelRequest {
model_mapping?: Record>
billing_model_source?: string
restrict_models?: boolean
- features_config?: Record
+ apply_pricing_to_account_stats?: boolean
+ account_stats_pricing_rules?: AccountStatsPricingRule[]
}
export interface UpdateChannelRequest {
@@ -69,7 +79,8 @@ export interface UpdateChannelRequest {
model_mapping?: Record>
billing_model_source?: string
restrict_models?: boolean
- features_config?: Record
+ apply_pricing_to_account_stats?: boolean
+ account_stats_pricing_rules?: AccountStatsPricingRule[]
}
interface PaginatedResponse {
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 99f8d535..dd45ea17 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -1844,7 +1844,18 @@ export default {
noPlatforms: 'Click "Add Platform" to start configuring the channel',
mappingCount: 'mappings',
pricingEntry: 'Pricing Entry',
- noModels: 'No models added'
+ noModels: 'No models added',
+ applyPricingToAccountStats: 'Apply Pricing to Account Stats',
+ applyPricingToAccountStatsDesc: 'When enabled, account statistics cost will use channel model pricing. Account rate multiplier still applies.',
+ accountStatsPricingRules: 'Custom Account Stats Pricing Rules',
+ addRule: 'Add Rule',
+ noRulesConfigured: 'No custom rules configured. Channel model pricing above will be used.',
+ ruleName: 'Rule name (optional)',
+ ruleGroups: 'Groups',
+ ruleAccounts: 'Account IDs',
+ ruleAccountsPlaceholder: 'Enter account IDs, comma-separated',
+ ruleModelPricing: 'Model Pricing',
+ noGroupsInChannel: 'No groups selected in platform tabs above'
}
},
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 7ef7ead0..bbfc7971 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -1923,7 +1923,18 @@ export default {
noPlatforms: '点击"添加平台"开始配置渠道',
mappingCount: '条映射',
pricingEntry: '定价配置',
- noModels: '未添加模型'
+ noModels: '未添加模型',
+ applyPricingToAccountStats: '应用模型定价到账号统计',
+ applyPricingToAccountStatsDesc: '启用后,账号统计费用将使用渠道模型定价计算。账号自身的统计倍率仍然生效。',
+ accountStatsPricingRules: '自定义账号统计定价规则',
+ addRule: '添加规则',
+ noRulesConfigured: '未配置自定义规则,将使用上方的模型定价。',
+ ruleName: '规则名称(可选)',
+ ruleGroups: '分组',
+ ruleAccounts: '账号 ID',
+ ruleAccountsPlaceholder: '输入账号 ID,逗号分隔',
+ ruleModelPricing: '模型定价',
+ noGroupsInChannel: '上方平台标签页中未选择分组'
}
},
diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue
index ce8d3c9c..a49e1694 100644
--- a/frontend/src/views/admin/ChannelsView.vue
+++ b/frontend/src/views/admin/ChannelsView.vue
@@ -306,24 +306,6 @@
-
-
-
-
-
- {{ t('admin.channels.form.webSearchEmulation') }}
-
-
- {{ t('admin.channels.form.webSearchEmulationHint') }}
-
-
- {{ t('admin.channels.form.webSearchEmulationGlobalDisabled') }}
-
-
-
-
-
-
+
+
+
+
+
+
+
+ {{ t('admin.channels.form.applyPricingToAccountStats', 'Apply Pricing to Account Stats') }}
+
+
+ {{ t('admin.channels.form.applyPricingToAccountStatsDesc', 'When enabled, account statistics cost will use channel model pricing. Account rate multiplier still applies.') }}
+
+
+
+
+
+
+
+
+
+ {{ t('admin.channels.form.accountStatsPricingRules', 'Custom Account Stats Pricing Rules') }}
+
+
+ + {{ t('admin.channels.form.addRule', 'Add Rule') }}
+
+
+
+
+ {{ t('admin.channels.form.noRulesConfigured', 'No custom rules configured. Channel model pricing above will be used.') }}
+
+
+
+
+
+
+
+ {{ t('common.delete', 'Delete') }}
+
+
+
+
+
+
+ {{ t('admin.channels.form.ruleGroups', 'Groups') }}
+
+
+
+
+ {{ getGroupNameById(gid) }}
+
+
+
+ {{ t('admin.channels.form.noGroupsInChannel', 'No groups selected in platform tabs above') }}
+
+
+
+
+
+
+ {{ t('admin.channels.form.ruleAccounts', 'Account IDs') }}
+
+
+
+
+
+
+
+
+ {{ t('admin.channels.form.ruleModelPricing', 'Model Pricing') }}
+
+
+ + {{ t('common.add', 'Add') }}
+
+
+
+ {{ t('admin.channels.form.noPricingRules', 'No pricing rules yet. Click "Add" to create one.') }}
+
+
+
+
+
+
@@ -441,9 +560,8 @@
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
-import { extractApiErrorMessage } from '@/utils/apiError'
import { adminAPI } from '@/api/admin'
-import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels'
+import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest, AccountStatsPricingRule } from '@/api/admin/channels'
import type { PricingFormEntry } from '@/components/admin/channel/types'
import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types'
import type { AdminGroup, GroupPlatform } from '@/types'
@@ -465,18 +583,6 @@ import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
const { t } = useI18n()
const appStore = useAppStore()
-// Web Search global enabled state (loaded once on mount)
-const webSearchGlobalEnabled = ref(false)
-async function loadWebSearchGlobalState() {
- try {
- const cfg = await adminAPI.settings.getWebSearchEmulationConfig()
- webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0
- } catch (err: unknown) {
- console.warn('Failed to load web search global state:', err)
- webSearchGlobalEnabled.value = false
- }
-}
-
// ── Platform Section type ──
interface PlatformSection {
platform: GroupPlatform
@@ -485,7 +591,6 @@ interface PlatformSection {
group_ids: number[]
model_mapping: Record
model_pricing: PricingFormEntry[]
- web_search_emulation: boolean
}
// ── Table columns ──
@@ -553,7 +658,14 @@ const form = reactive({
status: 'active',
restrict_models: false,
billing_model_source: 'channel_mapped' as string,
- platforms: [] as PlatformSection[]
+ platforms: [] as PlatformSection[],
+ apply_pricing_to_account_stats: false,
+ account_stats_pricing_rules: [] as Array<{
+ name: string
+ group_ids: number[]
+ account_ids: number[]
+ pricing: PricingFormEntry[]
+ }>
})
let abortController: AbortController | null = null
@@ -597,8 +709,7 @@ function addPlatformSection(platform: GroupPlatform) {
collapsed: false,
group_ids: [],
model_mapping: {},
- model_pricing: [],
- web_search_emulation: false,
+ model_pricing: []
})
}
@@ -711,15 +822,89 @@ function renameMappingKey(sectionIdx: number, oldKey: string, newKey: string) {
mapping[newKey] = value
}
+// ── Account Stats Pricing helpers ──
+function addAccountStatsRule() {
+ form.account_stats_pricing_rules.push({
+ name: '',
+ group_ids: [],
+ account_ids: [],
+ pricing: []
+ })
+}
+
+function addRulePricingEntry(ruleIndex: number) {
+ form.account_stats_pricing_rules[ruleIndex].pricing.push({
+ models: [],
+ billing_mode: 'token',
+ input_price: null,
+ output_price: null,
+ cache_write_price: null,
+ cache_read_price: null,
+ image_output_price: null,
+ per_request_price: null,
+ intervals: []
+ })
+}
+
// Remove the rule at ruleIndex (splice mutates in place, preserving Vue reactivity).
function removeAccountStatsRule(ruleIndex: number) {
  form.account_stats_pricing_rules.splice(ruleIndex, 1)
}
+
// Remove one pricing entry from a rule (splice mutates in place, preserving Vue reactivity).
function removeRulePricingEntry(ruleIndex: number, pricingIndex: number) {
  form.account_stats_pricing_rules[ruleIndex].pricing.splice(pricingIndex, 1)
}
+
+function getGroupNameById(groupId: number): string {
+ const group = allGroups.value.find(g => g.id === groupId)
+ return group ? group.name : `#${groupId}`
+}
+
+/** Collect all group_ids from enabled platform sections */
+const allFormGroupIds = computed(() => {
+ const ids = new Set()
+ for (const section of form.platforms) {
+ if (!section.enabled) continue
+ for (const gid of section.group_ids) {
+ ids.add(gid)
+ }
+ }
+ return [...ids]
+})
+
+function parseAccountIdsInput(value: string): number[] {
+ return value
+ .split(',')
+ .map(s => parseInt(s.trim()))
+ .filter(n => !isNaN(n) && n > 0)
+}
+
+function accountStatsRulesToAPI(): AccountStatsPricingRule[] {
+ return form.account_stats_pricing_rules.map(rule => ({
+ name: rule.name,
+ group_ids: rule.group_ids,
+ account_ids: rule.account_ids,
+ pricing: rule.pricing
+ .filter(p => p.models.length > 0)
+ .map(p => ({
+ platform: '',
+ models: p.models,
+ billing_mode: p.billing_mode,
+ input_price: mTokToPerToken(p.input_price),
+ output_price: mTokToPerToken(p.output_price),
+ cache_write_price: mTokToPerToken(p.cache_write_price),
+ cache_read_price: mTokToPerToken(p.cache_read_price),
+ image_output_price: mTokToPerToken(p.image_output_price),
+ per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null,
+ intervals: formIntervalsToAPI(p.intervals || [])
+ }))
+ }))
+}
+
// ── Form ↔ API conversion ──
-function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record>, features_config: Record } {
+function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record> } {
const group_ids: number[] = []
const model_pricing: ChannelModelPricing[] = []
const model_mapping: Record> = {}
- // Preserve existing features_config fields not managed by the form
- const featuresConfig: Record = editingChannel.value?.features_config
- ? { ...editingChannel.value.features_config }
- : {}
for (const section of form.platforms) {
if (!section.enabled) continue
@@ -748,19 +933,7 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[
}
}
- // Collect web_search_emulation (only anthropic platform supports it)
- const wsEmulation: Record = {}
- for (const section of form.platforms) {
- if (!section.enabled) continue
- if (section.web_search_emulation && section.platform === 'anthropic') {
- wsEmulation[section.platform] = true
- }
- }
- if (Object.keys(wsEmulation).length > 0) {
- featuresConfig.web_search_emulation = wsEmulation
- }
-
- return { group_ids, model_pricing, model_mapping, features_config: featuresConfig }
+ return { group_ids, model_pricing, model_mapping }
}
function apiToForm(channel: Channel): PlatformSection[] {
@@ -804,19 +977,13 @@ function apiToForm(channel: Channel): PlatformSection[] {
intervals: apiIntervalsToForm(p.intervals || [])
} as PricingFormEntry))
- // Read web_search_emulation from features_config
- const fc = channel.features_config
- const wsEmulation = fc?.web_search_emulation as Record | undefined
- const webSearchEnabled = wsEmulation?.[platform] === true
-
sections.push({
platform,
enabled: true,
collapsed: false,
group_ids: groupIds,
model_mapping: { ...mapping },
- model_pricing: pricing,
- web_search_emulation: webSearchEnabled,
+ model_pricing: pricing
})
}
@@ -841,10 +1008,10 @@ async function loadChannels() {
if (ctrl.signal.aborted || abortController !== ctrl) return
channels.value = response.items || []
pagination.total = response.total
- } catch (error: unknown) {
- const e = error as { name?: string; code?: string }
- if (e?.name === 'AbortError' || e?.code === 'ERR_CANCELED') return
- appStore.showError(extractApiErrorMessage(error, t('admin.channels.loadError', 'Failed to load channels')))
+ } catch (error: any) {
+ if (error?.name === 'AbortError' || error?.code === 'ERR_CANCELED') return
+ appStore.showError(t('admin.channels.loadError', 'Failed to load channels'))
+ console.error('Error loading channels:', error)
} finally {
if (abortController === ctrl) {
loading.value = false
@@ -909,6 +1076,8 @@ function resetForm() {
form.restrict_models = false
form.billing_model_source = 'channel_mapped'
form.platforms = []
+ form.apply_pricing_to_account_stats = false
+ form.account_stats_pricing_rules = []
activeTab.value = 'basic'
}
@@ -926,6 +1095,23 @@ async function openEditDialog(channel: Channel) {
form.status = channel.status
form.restrict_models = channel.restrict_models || false
form.billing_model_source = channel.billing_model_source || 'channel_mapped'
+ form.apply_pricing_to_account_stats = channel.apply_pricing_to_account_stats || false
+ form.account_stats_pricing_rules = (channel.account_stats_pricing_rules || []).map(rule => ({
+ name: rule.name || '',
+ group_ids: [...(rule.group_ids || [])],
+ account_ids: [...(rule.account_ids || [])],
+ pricing: (rule.pricing || []).map(p => ({
+ models: [...(p.models || [])],
+ billing_mode: p.billing_mode,
+ input_price: perTokenToMTok(p.input_price),
+ output_price: perTokenToMTok(p.output_price),
+ cache_write_price: perTokenToMTok(p.cache_write_price),
+ cache_read_price: perTokenToMTok(p.cache_read_price),
+ image_output_price: perTokenToMTok(p.image_output_price),
+ per_request_price: p.per_request_price,
+ intervals: apiIntervalsToForm(p.intervals || [])
+ } as PricingFormEntry))
+ }))
// Must load groups first so apiToForm can map groupID → platform
await Promise.all([loadGroups(), loadAllChannelsForConflict()])
form.platforms = apiToForm(channel)
@@ -1024,7 +1210,7 @@ async function handleSubmit() {
}
}
- const { group_ids, model_pricing, model_mapping, features_config } = formToAPI()
+ const { group_ids, model_pricing, model_mapping } = formToAPI()
submitting.value = true
try {
@@ -1038,7 +1224,8 @@ async function handleSubmit() {
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
billing_model_source: form.billing_model_source,
restrict_models: form.restrict_models,
- features_config,
+ apply_pricing_to_account_stats: form.apply_pricing_to_account_stats,
+ account_stats_pricing_rules: accountStatsRulesToAPI()
}
await adminAPI.channels.update(editingChannel.value.id, req)
appStore.showSuccess(t('admin.channels.updateSuccess', 'Channel updated'))
@@ -1051,17 +1238,20 @@ async function handleSubmit() {
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
billing_model_source: form.billing_model_source,
restrict_models: form.restrict_models,
- features_config,
+ apply_pricing_to_account_stats: form.apply_pricing_to_account_stats,
+ account_stats_pricing_rules: accountStatsRulesToAPI()
}
await adminAPI.channels.create(req)
appStore.showSuccess(t('admin.channels.createSuccess', 'Channel created'))
}
closeDialog()
loadChannels()
- } catch (error: unknown) {
- appStore.showError(extractApiErrorMessage(error, editingChannel.value
+ } catch (error: any) {
+ const msg = error.response?.data?.detail || (editingChannel.value
? t('admin.channels.updateError', 'Failed to update channel')
- : t('admin.channels.createError', 'Failed to create channel')))
+ : t('admin.channels.createError', 'Failed to create channel'))
+ appStore.showError(msg)
+ console.error('Error saving channel:', error)
} finally {
submitting.value = false
}
@@ -1099,8 +1289,9 @@ async function confirmDelete() {
showDeleteDialog.value = false
deletingChannel.value = null
loadChannels()
- } catch (error: unknown) {
- appStore.showError(extractApiErrorMessage(error, t('admin.channels.deleteError', 'Failed to delete channel')))
+ } catch (error: any) {
+ appStore.showError(error.response?.data?.detail || t('admin.channels.deleteError', 'Failed to delete channel'))
+ console.error('Error deleting channel:', error)
}
}
@@ -1108,7 +1299,6 @@ async function confirmDelete() {
onMounted(() => {
loadChannels()
loadGroups()
- loadWebSearchGlobalState()
})
onUnmounted(() => {
From fda61b067c1dbf99bb58a59c3dab3769906bb889 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 01:48:06 +0800
Subject: [PATCH 17/88] feat(websearch): proxy failover, timeout,
quota-weighted load balancing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Use proxyutil.ConfigureTransportProxy for unified proxy protocol support
(HTTP/HTTPS/SOCKS5/SOCKS5H), replacing ad-hoc HTTP-only proxy code
- Proxy errors return ErrProxyUnavailable → gateway triggers account switch
via UpstreamFailoverError instead of fallback to direct connection
- Timeout: proxy dial 3s, TLS handshake 3s, data transfer 60s
- Mark proxy unavailable for 5 minutes in Redis on connectivity failure
- Quota-weighted load balancing: providers with quota_limit>0 are selected
by remaining quota (weighted random); quota_limit=0 providers treated as
0% weight and placed last
---
backend/internal/pkg/websearch/manager.go | 271 ++++++++++++++----
.../service/gateway_websearch_emulation.go | 8 +
2 files changed, 227 insertions(+), 52 deletions(-)
diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go
index 95da70e4..615014e0 100644
--- a/backend/internal/pkg/websearch/manager.go
+++ b/backend/internal/pkg/websearch/manager.go
@@ -3,8 +3,11 @@ package websearch
import (
"context"
"crypto/tls"
+ "errors"
"fmt"
"log/slog"
+ "math/rand"
+ "net"
"net/http"
"net/url"
"sort"
@@ -12,6 +15,7 @@ import (
"sync"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/redis/go-redis/v9"
)
@@ -30,6 +34,7 @@ type ProviderConfig struct {
QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
QuotaRefreshInterval string `json:"quota_refresh_interval"` // QuotaRefreshDaily / Weekly / Monthly
ProxyURL string `json:"-"` // resolved proxy URL (not persisted)
+ ProxyID int64 `json:"-"` // resolved proxy ID for unavailability tracking
ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds)
}
@@ -42,22 +47,30 @@ type Manager struct {
clientCache map[string]*http.Client
}
+// Timeout constants for proxy and search operations.
const (
+ proxyDialTimeout = 3 * time.Second // proxy TCP connection timeout
+ proxyTLSTimeout = 3 * time.Second // TLS handshake timeout
+ searchDataTimeout = 60 * time.Second // response data transfer timeout
+ searchRequestTimeout = searchDataTimeout + proxyDialTimeout
+
quotaKeyPrefix = "websearch:quota:"
- searchRequestTimeout = 30 * time.Second
+ proxyUnavailableKey = "websearch:proxy_unavailable:%d"
+ proxyUnavailableTTL = 5 * time.Minute
quotaTTLBuffer = 24 * time.Hour
maxCachedClients = 100
)
+// ErrProxyUnavailable indicates the search failed due to a proxy connectivity issue.
+// Callers may use this to trigger account switching instead of direct fallback.
+var ErrProxyUnavailable = errors.New("websearch: proxy unavailable")
+
// quotaIncrScript atomically increments the counter and sets TTL on first creation.
-// KEYS[1] = quota key, ARGV[1] = TTL in seconds.
-// Returns the new counter value.
var quotaIncrScript = redis.NewScript(`
local val = redis.call('INCR', KEYS[1])
if val == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[1])
else
- -- Defensive: ensure TTL exists even if a prior EXPIRE failed
local ttl = redis.call('TTL', KEYS[1])
if ttl == -1 then
redis.call('EXPIRE', KEYS[1], ARGV[1])
@@ -80,16 +93,22 @@ func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager {
}
}
-// SearchWithBestProvider selects the highest-priority available provider,
+// SearchWithBestProvider selects a provider using quota-weighted load balancing,
// reserves quota, executes the search, and rolls back quota on failure.
+// If the search fails due to a proxy error, the proxy is marked unavailable for 5 minutes.
func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
if strings.TrimSpace(req.Query) == "" {
return nil, "", fmt.Errorf("websearch: empty search query")
}
- for _, cfg := range m.configs {
- if !m.isProviderAvailable(cfg) {
- continue
- }
+
+ candidates := m.filterAvailableProviders(ctx, req.ProxyURL)
+ if len(candidates) == 0 {
+ return nil, "", fmt.Errorf("websearch: no available provider (all exhausted, expired, or proxy unavailable)")
+ }
+
+ selected := m.selectByQuotaWeight(ctx, candidates)
+
+ for _, cfg := range selected {
allowed, incremented := m.tryReserveQuota(ctx, cfg)
if !allowed {
continue
@@ -99,6 +118,12 @@ func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest)
if incremented {
m.rollbackQuota(ctx, cfg)
}
+ if isProxyError(err) {
+ m.markProxyUnavailable(ctx, cfg, req.ProxyURL)
+ slog.Warn("websearch: proxy error, marking unavailable",
+ "provider", cfg.Type, "error", err)
+ return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error())
+ }
slog.Warn("websearch: provider search failed",
"provider", cfg.Type, "error", err)
continue
@@ -108,6 +133,76 @@ func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest)
return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)")
}
+// filterAvailableProviders returns providers that have API keys, are not expired,
+// and whose proxies are not marked unavailable.
+func (m *Manager) filterAvailableProviders(ctx context.Context, accountProxyURL string) []ProviderConfig {
+ var out []ProviderConfig
+ for _, cfg := range m.configs {
+ if !m.isProviderAvailable(cfg) {
+ continue
+ }
+ proxyID := resolveProxyID(cfg, accountProxyURL)
+ if proxyID > 0 && !m.isProxyAvailable(ctx, proxyID) {
+ slog.Debug("websearch: proxy marked unavailable, skipping",
+ "provider", cfg.Type, "proxy_id", proxyID)
+ continue
+ }
+ out = append(out, cfg)
+ }
+ return out
+}
+
+// selectByQuotaWeight orders candidates by remaining quota weight.
+// Providers with quota_limit=0 (no limit set) get weight 0 and are placed last.
+// Among providers with quota, higher remaining quota = higher priority.
+func (m *Manager) selectByQuotaWeight(ctx context.Context, candidates []ProviderConfig) []ProviderConfig {
+ type weighted struct {
+ cfg ProviderConfig
+ weight int64
+ }
+ items := make([]weighted, 0, len(candidates))
+ for _, cfg := range candidates {
+ w := int64(0)
+ if cfg.QuotaLimit > 0 {
+ used, _ := m.GetUsage(ctx, cfg.Type, cfg.QuotaRefreshInterval)
+ remaining := cfg.QuotaLimit - used
+ if remaining > 0 {
+ w = remaining
+ }
+ }
+ items = append(items, weighted{cfg: cfg, weight: w})
+ }
+
+ // Separate providers with quota (weight > 0) from those without (weight == 0)
+ var withQuota, withoutQuota []weighted
+ for _, item := range items {
+ if item.weight > 0 {
+ withQuota = append(withQuota, item)
+ } else {
+ withoutQuota = append(withoutQuota, item)
+ }
+ }
+
+ // Within quota group: weighted random sort (higher remaining = more likely first)
+ if len(withQuota) > 1 {
+ sort.Slice(withQuota, func(i, j int) bool {
+ wi := float64(withQuota[i].weight) * (0.5 + rand.Float64())
+ wj := float64(withQuota[j].weight) * (0.5 + rand.Float64())
+ return wi > wj
+ })
+ }
+
+ // Build final order: quota providers first, then no-quota providers (original priority order)
+ result := make([]ProviderConfig, 0, len(candidates))
+ for _, item := range withQuota {
+ result = append(result, item.cfg)
+ }
+ for _, item := range withoutQuota {
+ result = append(result, item.cfg)
+ }
+ return result
+}
+
func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
if cfg.APIKey == "" {
return false
@@ -120,26 +215,80 @@ func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
return true
}
-// tryReserveQuota atomically increments the counter via Lua script and checks limit.
-// Returns (allowed, incremented): allowed=true means the request may proceed;
-// incremented=true means the Redis counter was actually incremented (so rollback is needed on failure).
+// --- Proxy availability tracking ---
+
+// markProxyUnavailable marks the effective proxy as unavailable for proxyUnavailableTTL.
+func (m *Manager) markProxyUnavailable(ctx context.Context, cfg ProviderConfig, accountProxyURL string) {
+ proxyID := resolveProxyID(cfg, accountProxyURL)
+ if proxyID <= 0 || m.redis == nil {
+ return
+ }
+ key := fmt.Sprintf(proxyUnavailableKey, proxyID)
+ if err := m.redis.Set(ctx, key, "1", proxyUnavailableTTL).Err(); err != nil {
+ slog.Warn("websearch: failed to mark proxy unavailable",
+ "proxy_id", proxyID, "error", err)
+ }
+}
+
+// isProxyAvailable checks whether a proxy is currently marked as unavailable.
+func (m *Manager) isProxyAvailable(ctx context.Context, proxyID int64) bool {
+ if m.redis == nil || proxyID <= 0 {
+ return true
+ }
+ key := fmt.Sprintf(proxyUnavailableKey, proxyID)
+ val, err := m.redis.Get(ctx, key).Result()
+ if err != nil {
+ return true // Redis error → assume available
+ }
+ return val == ""
+}
+
+// resolveProxyID determines the effective proxy ID for a provider+account combination.
+func resolveProxyID(cfg ProviderConfig, accountProxyURL string) int64 {
+ if accountProxyURL != "" {
+ return 0 // account proxy has no ID in provider config
+ }
+ return cfg.ProxyID
+}
+
+// isProxyError checks whether the error is likely caused by proxy connectivity.
+func isProxyError(err error) bool {
+ if err == nil {
+ return false
+ }
+ var netErr net.Error
+ if errors.As(err, &netErr) && netErr.Timeout() {
+ return true
+ }
+ var opErr *net.OpError
+ if errors.As(err, &opErr) {
+ return true
+ }
+ msg := err.Error()
+ return strings.Contains(msg, "proxy") ||
+ strings.Contains(msg, "SOCKS") ||
+ strings.Contains(msg, "connection refused") ||
+ strings.Contains(msg, "no such host") ||
+ strings.Contains(msg, "i/o timeout")
+}
+
+// --- Quota management ---
+
func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) {
if cfg.QuotaLimit <= 0 {
- return true, false // unlimited, no INCR
+ return true, false
}
if m.redis == nil {
- slog.Warn("websearch: Redis unavailable, quota check skipped",
- "provider", cfg.Type)
- return true, false // allowed but not incremented
+ slog.Warn("websearch: Redis unavailable, quota check skipped", "provider", cfg.Type)
+ return true, false
}
key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
ttlSec := int(quotaTTL(cfg.QuotaRefreshInterval).Seconds())
-
newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64()
if err != nil {
slog.Warn("websearch: quota Lua INCR failed, allowing request",
"provider", cfg.Type, "error", err)
- return true, false // allowed but not incremented
+ return true, false
}
if newVal > cfg.QuotaLimit {
if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil {
@@ -148,12 +297,11 @@ func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool
}
slog.Info("websearch: provider quota exhausted",
"provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit)
- return false, false // rejected, already rolled back
+ return false, false
}
- return true, true // allowed and incremented
+ return true, true
}
-// rollbackQuota decrements the counter after a search failure.
func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
if cfg.QuotaLimit <= 0 || m.redis == nil {
return
@@ -165,16 +313,64 @@ func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
}
}
+// --- Search execution ---
+
func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) {
proxyURL := cfg.ProxyURL
if req.ProxyURL != "" {
proxyURL = req.ProxyURL
}
- client := m.getOrCreateHTTPClient(proxyURL)
+ client, err := m.getOrCreateHTTPClient(proxyURL)
+ if err != nil {
+ return nil, fmt.Errorf("websearch: %w", err)
+ }
provider := m.buildProvider(cfg, client)
return provider.Search(ctx, req)
}
+// --- HTTP client cache ---
+
+func (m *Manager) getOrCreateHTTPClient(proxyURL string) (*http.Client, error) {
+ m.clientMu.Lock()
+ defer m.clientMu.Unlock()
+
+ if c, ok := m.clientCache[proxyURL]; ok {
+ return c, nil
+ }
+ if len(m.clientCache) >= maxCachedClients {
+ m.clientCache = make(map[string]*http.Client)
+ }
+ c, err := newHTTPClient(proxyURL)
+ if err != nil {
+ return nil, err
+ }
+ m.clientCache[proxyURL] = c
+ return c, nil
+}
+
+// newHTTPClient creates an HTTP client with proper timeout settings.
+// Uses proxyutil.ConfigureTransportProxy for unified proxy protocol support
+// (HTTP/HTTPS/SOCKS5/SOCKS5H).
+// Returns error if proxyURL is invalid — never falls back to direct connection.
+func newHTTPClient(proxyURL string) (*http.Client, error) {
+ transport := &http.Transport{
+ TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
+ DialContext: (&net.Dialer{Timeout: proxyDialTimeout}).DialContext,
+ TLSHandshakeTimeout: proxyTLSTimeout,
+ ResponseHeaderTimeout: searchDataTimeout,
+ }
+ if proxyURL != "" {
+ parsed, err := url.Parse(proxyURL)
+ if err != nil {
+ return nil, fmt.Errorf("invalid proxy URL %q: %w", proxyURL, err)
+ }
+ if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
+ return nil, fmt.Errorf("configure proxy: %w", err)
+ }
+ }
+ return &http.Client{Transport: transport, Timeout: searchRequestTimeout}, nil
+}
+
// GetUsage returns the current usage count for the given provider.
func (m *Manager) GetUsage(ctx context.Context, providerType, refreshInterval string) (int64, error) {
if m.redis == nil {
@@ -198,35 +394,6 @@ func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 {
return result
}
-// --- HTTP client cache (bounded) ---
-
-func (m *Manager) getOrCreateHTTPClient(proxyURL string) *http.Client {
- m.clientMu.Lock()
- defer m.clientMu.Unlock()
-
- if c, ok := m.clientCache[proxyURL]; ok {
- return c
- }
- if len(m.clientCache) >= maxCachedClients {
- m.clientCache = make(map[string]*http.Client) // evict all
- }
- c := newHTTPClient(proxyURL)
- m.clientCache[proxyURL] = c
- return c
-}
-
-func newHTTPClient(proxyURL string) *http.Client {
- transport := &http.Transport{
- TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
- }
- if proxyURL != "" {
- if u, err := url.Parse(proxyURL); err == nil {
- transport.Proxy = http.ProxyURL(u)
- }
- }
- return &http.Client{Transport: transport, Timeout: searchRequestTimeout}
-}
-
// --- Provider factory ---
func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider {
@@ -256,7 +423,7 @@ func periodKey(refreshInterval string) string {
case QuotaRefreshWeekly:
year, week := now.ISOWeek()
return fmt.Sprintf("%d-W%02d", year, week)
- default: // QuotaRefreshMonthly
+ default:
return now.Format("2006-01")
}
}
diff --git a/backend/internal/service/gateway_websearch_emulation.go b/backend/internal/service/gateway_websearch_emulation.go
index fbea96c0..b3a4aa69 100644
--- a/backend/internal/service/gateway_websearch_emulation.go
+++ b/backend/internal/service/gateway_websearch_emulation.go
@@ -3,6 +3,7 @@ package service
import (
"context"
"encoding/json"
+ "errors"
"fmt"
"log/slog"
"net/http"
@@ -147,6 +148,13 @@ func (s *GatewayService) handleWebSearchEmulation(
resp, providerName, err := doWebSearch(ctx, account, query)
if err != nil {
+ // Proxy unavailable → trigger account switch via UpstreamFailoverError
+ if errors.Is(err, websearch.ErrProxyUnavailable) {
+ return nil, &UpstreamFailoverError{
+ StatusCode: http.StatusBadGateway,
+ ResponseBody: []byte(err.Error()),
+ }
+ }
return nil, err
}
From 499159870c4b92c7ab37cdefe6f3ca044369c1c1 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 01:55:00 +0800
Subject: [PATCH 18/88] fix: gofmt websearch manager
---
backend/internal/pkg/websearch/manager.go | 20 +++++++++++---------
1 file changed, 11 insertions(+), 9 deletions(-)
diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go
index 615014e0..e6334b70 100644
--- a/backend/internal/pkg/websearch/manager.go
+++ b/backend/internal/pkg/websearch/manager.go
@@ -54,11 +54,11 @@ const (
searchDataTimeout = 60 * time.Second // response data transfer timeout
searchRequestTimeout = searchDataTimeout + proxyDialTimeout
- quotaKeyPrefix = "websearch:quota:"
- proxyUnavailableKey = "websearch:proxy_unavailable:%d"
- proxyUnavailableTTL = 5 * time.Minute
- quotaTTLBuffer = 24 * time.Hour
- maxCachedClients = 100
+ quotaKeyPrefix = "websearch:quota:"
+ proxyUnavailableKey = "websearch:proxy_unavailable:%d"
+ proxyUnavailableTTL = 5 * time.Minute
+ quotaTTLBuffer = 24 * time.Hour
+ maxCachedClients = 100
)
// ErrProxyUnavailable indicates the search failed due to a proxy connectivity issue.
@@ -152,14 +152,16 @@ func (m *Manager) filterAvailableProviders(ctx context.Context, accountProxyURL
return out
}
+// weighted is a provider candidate with computed quota weight.
+type weighted struct {
+ cfg ProviderConfig
+ weight int64
+}
+
// selectByQuotaWeight orders candidates by remaining quota weight.
// Providers with quota_limit=0 (no limit set) get weight 0 and are placed last.
// Among providers with quota, higher remaining quota = higher priority.
func (m *Manager) selectByQuotaWeight(ctx context.Context, candidates []ProviderConfig) []ProviderConfig {
- type weighted struct {
- cfg ProviderConfig
- weight int64
- }
items := make([]weighted, 0, len(candidates))
for _, cfg := range candidates {
w := int64(0)
From 60b0fa81ec4913ba3485c23ee6015e83253768d3 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 02:11:50 +0800
Subject: [PATCH 19/88] fix(websearch): improve isProxyError detection and add
manager tests
- Add TLS error detection to isProxyError (RecordHeaderError, handshake)
- Case-insensitive error string matching
- Add 19 unit tests for: isProviderAvailable, resolveProxyID,
isProxyError, isProxyAvailable, selectByQuotaWeight, newHTTPClient
---
backend/internal/pkg/websearch/manager.go | 20 ++-
.../internal/pkg/websearch/manager_test.go | 132 ++++++++++++++++++
2 files changed, 147 insertions(+), 5 deletions(-)
diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go
index e6334b70..7db3d9a2 100644
--- a/backend/internal/pkg/websearch/manager.go
+++ b/backend/internal/pkg/websearch/manager.go
@@ -253,25 +253,35 @@ func resolveProxyID(cfg ProviderConfig, accountProxyURL string) int64 {
return cfg.ProxyID
}
-// isProxyError checks whether the error is likely caused by proxy connectivity.
+// isProxyError checks whether the error is likely caused by proxy or network connectivity
+// (as opposed to an API-level error from the search provider).
func isProxyError(err error) bool {
if err == nil {
return false
}
+ // Network-level errors (timeout, connection refused, DNS failure)
var netErr net.Error
- if errors.As(err, &netErr) && netErr.Timeout() {
+ if errors.As(err, &netErr) {
return true
}
var opErr *net.OpError
if errors.As(err, &opErr) {
return true
}
- msg := err.Error()
+ // TLS handshake failures (often caused by proxy intercepting/blocking)
+ var tlsErr *tls.RecordHeaderError
+ if errors.As(err, &tlsErr) {
+ return true
+ }
+ // String-based detection for wrapped errors
+ msg := strings.ToLower(err.Error())
return strings.Contains(msg, "proxy") ||
- strings.Contains(msg, "SOCKS") ||
+ strings.Contains(msg, "socks") ||
strings.Contains(msg, "connection refused") ||
strings.Contains(msg, "no such host") ||
- strings.Contains(msg, "i/o timeout")
+ strings.Contains(msg, "i/o timeout") ||
+ strings.Contains(msg, "tls handshake") ||
+ strings.Contains(msg, "certificate")
}
// --- Quota management ---
diff --git a/backend/internal/pkg/websearch/manager_test.go b/backend/internal/pkg/websearch/manager_test.go
index 4387a2ee..d3cd29d6 100644
--- a/backend/internal/pkg/websearch/manager_test.go
+++ b/backend/internal/pkg/websearch/manager_test.go
@@ -3,6 +3,7 @@ package websearch
import (
"context"
"encoding/json"
+ "fmt"
"net/http"
"net/http/httptest"
"testing"
@@ -147,3 +148,134 @@ func TestQuotaRedisKey_Format(t *testing.T) {
key := quotaRedisKey("brave", QuotaRefreshDaily)
require.Contains(t, key, "websearch:quota:brave:")
}
+
+// --- isProviderAvailable ---
+
+func TestIsProviderAvailable_EmptyAPIKey(t *testing.T) {
+ m := NewManager(nil, nil)
+ require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: ""}))
+}
+
+func TestIsProviderAvailable_Expired(t *testing.T) {
+ m := NewManager(nil, nil)
+ past := time.Now().Add(-1 * time.Hour).Unix()
+ require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &past}))
+}
+
+func TestIsProviderAvailable_Valid(t *testing.T) {
+ m := NewManager(nil, nil)
+ future := time.Now().Add(1 * time.Hour).Unix()
+ require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &future}))
+ require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k"})) // no expiry
+}
+
+// --- resolveProxyID ---
+
+func TestResolveProxyID_AccountProxyOverrides(t *testing.T) {
+ cfg := ProviderConfig{ProxyID: 42}
+ // account proxy present → return 0 (account proxy has no config-level ID)
+ require.Equal(t, int64(0), resolveProxyID(cfg, "http://account-proxy:8080"))
+ // no account proxy → return provider's proxy ID
+ require.Equal(t, int64(42), resolveProxyID(cfg, ""))
+}
+
+// --- isProxyError ---
+
+func TestIsProxyError_Nil(t *testing.T) {
+ require.False(t, isProxyError(nil))
+}
+
+func TestIsProxyError_ConnectionRefused(t *testing.T) {
+ err := fmt.Errorf("dial tcp: connection refused")
+ require.True(t, isProxyError(err))
+}
+
+func TestIsProxyError_Timeout(t *testing.T) {
+ err := fmt.Errorf("i/o timeout while connecting to proxy")
+ require.True(t, isProxyError(err))
+}
+
+func TestIsProxyError_SOCKS(t *testing.T) {
+ err := fmt.Errorf("socks connect failed")
+ require.True(t, isProxyError(err))
+}
+
+func TestIsProxyError_TLSHandshake(t *testing.T) {
+ err := fmt.Errorf("tls handshake timeout")
+ require.True(t, isProxyError(err))
+}
+
+func TestIsProxyError_APIError_NotProxy(t *testing.T) {
+ err := fmt.Errorf("API rate limit exceeded")
+ require.False(t, isProxyError(err))
+}
+
+// --- isProxyAvailable (nil Redis) ---
+
+func TestIsProxyAvailable_NilRedis(t *testing.T) {
+ m := NewManager(nil, nil)
+ require.True(t, m.isProxyAvailable(context.Background(), 42))
+}
+
+func TestIsProxyAvailable_ZeroID(t *testing.T) {
+ m := NewManager(nil, nil)
+ require.True(t, m.isProxyAvailable(context.Background(), 0))
+}
+
+// --- selectByQuotaWeight ---
+
+func TestSelectByQuotaWeight_NoQuotaLast(t *testing.T) {
+ m := NewManager(nil, nil) // nil Redis → GetUsage returns 0
+ candidates := []ProviderConfig{
+ {Type: "brave", APIKey: "k1", QuotaLimit: 0}, // no limit → weight 0
+ {Type: "tavily", APIKey: "k2", QuotaLimit: 100}, // remaining 100
+ }
+ result := m.selectByQuotaWeight(context.Background(), candidates)
+ require.Len(t, result, 2)
+ // tavily (with quota) should come first
+ require.Equal(t, "tavily", result[0].Type)
+ require.Equal(t, "brave", result[1].Type)
+}
+
+func TestSelectByQuotaWeight_AllNoQuota(t *testing.T) {
+ m := NewManager(nil, nil)
+ candidates := []ProviderConfig{
+ {Type: "brave", APIKey: "k1", QuotaLimit: 0},
+ {Type: "tavily", APIKey: "k2", QuotaLimit: 0},
+ }
+ result := m.selectByQuotaWeight(context.Background(), candidates)
+ require.Len(t, result, 2)
+ // both have weight 0, original order preserved
+}
+
+func TestSelectByQuotaWeight_Empty(t *testing.T) {
+ m := NewManager(nil, nil)
+ result := m.selectByQuotaWeight(context.Background(), nil)
+ require.Empty(t, result)
+}
+
+// --- newHTTPClient ---
+
+func TestNewHTTPClient_NoProxy(t *testing.T) {
+ c, err := newHTTPClient("")
+ require.NoError(t, err)
+ require.NotNil(t, c)
+}
+
+func TestNewHTTPClient_InvalidProxy(t *testing.T) {
+ _, err := newHTTPClient("://bad-url")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "invalid proxy URL")
+}
+
+func TestNewHTTPClient_ValidHTTPProxy(t *testing.T) {
+ c, err := newHTTPClient("http://proxy.example.com:8080")
+ require.NoError(t, err)
+ require.NotNil(t, c)
+}
+
+func TestNewHTTPClient_ValidSOCKS5Proxy(t *testing.T) {
+ c, err := newHTTPClient("socks5://proxy.example.com:1080")
+ require.NoError(t, err)
+ require.NotNil(t, c)
+}
From b32d1a2c9fb658d2691aac3711bc47383d9bd714 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 02:48:57 +0800
Subject: [PATCH 20/88] feat(notify): add balance low & account quota
notification system
- User balance low notification: email alert when balance drops below
configurable threshold (user email + verified extra emails)
- Account quota notification: broadcast email to admin-configured
recipients when daily/weekly/total quota usage exceeds alert threshold
- Admin settings: global enable/disable, default threshold, quota
notification email list (Email Settings tab)
- User profile: enable/disable, custom threshold, add/remove extra
notification emails with verification code flow
- Account quota: per-dimension alert toggle and threshold in quota
control card
- Trigger logic: first-crossing only (old >= threshold && new < threshold
for balance; old < threshold && new >= threshold for quota), naturally
prevents duplicate notifications without Redis dedup
---
backend/cmd/server/wire_gen.go | 9 +-
backend/ent/migrate/schema.go | 3 +
backend/ent/mutation.go | 217 +++++++++++-
backend/ent/runtime/runtime.go | 8 +
backend/ent/schema/user.go | 11 +
backend/ent/user.go | 42 ++-
backend/ent/user/user.go | 28 ++
backend/ent/user/where.go | 140 ++++++++
backend/ent/user_create.go | 228 ++++++++++++
backend/ent/user_update.go | 140 ++++++++
backend/go.sum | 10 +
.../internal/handler/admin/setting_handler.go | 64 ++--
backend/internal/handler/dto/mappers.go | 43 ++-
backend/internal/handler/dto/settings.go | 5 +
backend/internal/handler/dto/types.go | 13 +
backend/internal/handler/user_handler.go | 113 +++++-
backend/internal/repository/api_key_repo.go | 41 ++-
backend/internal/repository/email_cache.go | 35 ++
backend/internal/repository/user_repo.go | 23 +-
backend/internal/server/routes/user.go | 8 +
backend/internal/service/account.go | 39 +++
.../service/auth_service_register_test.go | 12 +
.../service/balance_notify_service.go | 328 ++++++++++++++++++
backend/internal/service/domain_constants.go | 9 +-
backend/internal/service/email_service.go | 5 +
.../service/gateway_record_usage_test.go | 1 +
backend/internal/service/gateway_service.go | 39 ++-
.../openai_gateway_record_usage_test.go | 1 +
.../service/openai_gateway_service.go | 14 +-
.../openai_ws_protocol_forward_test.go | 1 +
backend/internal/service/setting_service.go | 38 +-
backend/internal/service/settings_view.go | 10 +-
backend/internal/service/user.go | 5 +
backend/internal/service/user_service.go | 172 ++++++++-
backend/internal/service/user_service_test.go | 12 +-
backend/internal/service/wire.go | 6 +
.../101_add_balance_notify_fields.sql | 4 +
frontend/src/api/admin/settings.ts | 9 +
frontend/src/api/user.ts | 33 +-
.../components/account/EditAccountModal.vue | 84 +++++
.../src/components/account/QuotaLimitCard.vue | 113 +++++-
.../user/profile/ProfileBalanceNotifyCard.vue | 204 +++++++++++
frontend/src/i18n/locales/en.ts | 47 +++
frontend/src/i18n/locales/zh.ts | 47 +++
frontend/src/types/index.ts | 3 +
frontend/src/views/admin/SettingsView.vue | 72 +++-
frontend/src/views/user/ProfileView.vue | 7 +
47 files changed, 2375 insertions(+), 121 deletions(-)
create mode 100644 backend/internal/service/balance_notify_service.go
create mode 100644 backend/migrations/101_add_balance_notify_fields.sql
create mode 100644 frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 0a0cc84b..8c47b2bd 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -68,7 +68,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
- userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
+ userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
@@ -78,7 +78,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
- userHandler := handler.NewUserHandler(userService)
+ userHandler := handler.NewUserHandler(userService, emailService, emailCache)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
@@ -176,9 +176,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
- gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
+ balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository)
+ gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
- openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService)
+ openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index e947b2e8..4f31883b 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -1078,6 +1078,9 @@ var (
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
+ {Name: "balance_notify_enabled", Type: field.TypeBool, Default: true},
+ {Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
+ {Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 6b2fa838..cdaf363a 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -28210,6 +28210,10 @@ type UserMutation struct {
totp_secret_encrypted *string
totp_enabled *bool
totp_enabled_at *time.Time
+ balance_notify_enabled *bool
+ balance_notify_threshold *float64
+ addbalance_notify_threshold *float64
+ balance_notify_extra_emails *string
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -28927,6 +28931,148 @@ func (m *UserMutation) ResetTotpEnabledAt() {
delete(m.clearedFields, user.FieldTotpEnabledAt)
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (m *UserMutation) SetBalanceNotifyEnabled(b bool) {
+ m.balance_notify_enabled = &b
+}
+
+// BalanceNotifyEnabled returns the value of the "balance_notify_enabled" field in the mutation.
+func (m *UserMutation) BalanceNotifyEnabled() (r bool, exists bool) {
+ v := m.balance_notify_enabled
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBalanceNotifyEnabled returns the old "balance_notify_enabled" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldBalanceNotifyEnabled(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBalanceNotifyEnabled is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBalanceNotifyEnabled requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBalanceNotifyEnabled: %w", err)
+ }
+ return oldValue.BalanceNotifyEnabled, nil
+}
+
+// ResetBalanceNotifyEnabled resets all changes to the "balance_notify_enabled" field.
+func (m *UserMutation) ResetBalanceNotifyEnabled() {
+ m.balance_notify_enabled = nil
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (m *UserMutation) SetBalanceNotifyThreshold(f float64) {
+ m.balance_notify_threshold = &f
+ m.addbalance_notify_threshold = nil
+}
+
+// BalanceNotifyThreshold returns the value of the "balance_notify_threshold" field in the mutation.
+func (m *UserMutation) BalanceNotifyThreshold() (r float64, exists bool) {
+ v := m.balance_notify_threshold
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBalanceNotifyThreshold returns the old "balance_notify_threshold" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldBalanceNotifyThreshold(ctx context.Context) (v *float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBalanceNotifyThreshold is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBalanceNotifyThreshold requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBalanceNotifyThreshold: %w", err)
+ }
+ return oldValue.BalanceNotifyThreshold, nil
+}
+
+// AddBalanceNotifyThreshold adds f to the "balance_notify_threshold" field.
+func (m *UserMutation) AddBalanceNotifyThreshold(f float64) {
+ if m.addbalance_notify_threshold != nil {
+ *m.addbalance_notify_threshold += f
+ } else {
+ m.addbalance_notify_threshold = &f
+ }
+}
+
+// AddedBalanceNotifyThreshold returns the value that was added to the "balance_notify_threshold" field in this mutation.
+func (m *UserMutation) AddedBalanceNotifyThreshold() (r float64, exists bool) {
+ v := m.addbalance_notify_threshold
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (m *UserMutation) ClearBalanceNotifyThreshold() {
+ m.balance_notify_threshold = nil
+ m.addbalance_notify_threshold = nil
+ m.clearedFields[user.FieldBalanceNotifyThreshold] = struct{}{}
+}
+
+// BalanceNotifyThresholdCleared returns if the "balance_notify_threshold" field was cleared in this mutation.
+func (m *UserMutation) BalanceNotifyThresholdCleared() bool {
+ _, ok := m.clearedFields[user.FieldBalanceNotifyThreshold]
+ return ok
+}
+
+// ResetBalanceNotifyThreshold resets all changes to the "balance_notify_threshold" field.
+func (m *UserMutation) ResetBalanceNotifyThreshold() {
+ m.balance_notify_threshold = nil
+ m.addbalance_notify_threshold = nil
+ delete(m.clearedFields, user.FieldBalanceNotifyThreshold)
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (m *UserMutation) SetBalanceNotifyExtraEmails(s string) {
+ m.balance_notify_extra_emails = &s
+}
+
+// BalanceNotifyExtraEmails returns the value of the "balance_notify_extra_emails" field in the mutation.
+func (m *UserMutation) BalanceNotifyExtraEmails() (r string, exists bool) {
+ v := m.balance_notify_extra_emails
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBalanceNotifyExtraEmails returns the old "balance_notify_extra_emails" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldBalanceNotifyExtraEmails(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBalanceNotifyExtraEmails is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBalanceNotifyExtraEmails requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBalanceNotifyExtraEmails: %w", err)
+ }
+ return oldValue.BalanceNotifyExtraEmails, nil
+}
+
+// ResetBalanceNotifyExtraEmails resets all changes to the "balance_notify_extra_emails" field.
+func (m *UserMutation) ResetBalanceNotifyExtraEmails() {
+ m.balance_notify_extra_emails = nil
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -29501,7 +29647,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
- fields := make([]string, 0, 14)
+ fields := make([]string, 0, 17)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -29544,6 +29690,15 @@ func (m *UserMutation) Fields() []string {
if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.balance_notify_enabled != nil {
+ fields = append(fields, user.FieldBalanceNotifyEnabled)
+ }
+ if m.balance_notify_threshold != nil {
+ fields = append(fields, user.FieldBalanceNotifyThreshold)
+ }
+ if m.balance_notify_extra_emails != nil {
+ fields = append(fields, user.FieldBalanceNotifyExtraEmails)
+ }
return fields
}
@@ -29580,6 +29735,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.TotpEnabled()
case user.FieldTotpEnabledAt:
return m.TotpEnabledAt()
+ case user.FieldBalanceNotifyEnabled:
+ return m.BalanceNotifyEnabled()
+ case user.FieldBalanceNotifyThreshold:
+ return m.BalanceNotifyThreshold()
+ case user.FieldBalanceNotifyExtraEmails:
+ return m.BalanceNotifyExtraEmails()
}
return nil, false
}
@@ -29617,6 +29778,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx)
+ case user.FieldBalanceNotifyEnabled:
+ return m.OldBalanceNotifyEnabled(ctx)
+ case user.FieldBalanceNotifyThreshold:
+ return m.OldBalanceNotifyThreshold(ctx)
+ case user.FieldBalanceNotifyExtraEmails:
+ return m.OldBalanceNotifyExtraEmails(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -29724,6 +29891,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetTotpEnabledAt(v)
return nil
+ case user.FieldBalanceNotifyEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBalanceNotifyEnabled(v)
+ return nil
+ case user.FieldBalanceNotifyThreshold:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBalanceNotifyThreshold(v)
+ return nil
+ case user.FieldBalanceNotifyExtraEmails:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBalanceNotifyExtraEmails(v)
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -29738,6 +29926,9 @@ func (m *UserMutation) AddedFields() []string {
if m.addconcurrency != nil {
fields = append(fields, user.FieldConcurrency)
}
+ if m.addbalance_notify_threshold != nil {
+ fields = append(fields, user.FieldBalanceNotifyThreshold)
+ }
return fields
}
@@ -29750,6 +29941,8 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedBalance()
case user.FieldConcurrency:
return m.AddedConcurrency()
+ case user.FieldBalanceNotifyThreshold:
+ return m.AddedBalanceNotifyThreshold()
}
return nil, false
}
@@ -29773,6 +29966,13 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
}
m.AddConcurrency(v)
return nil
+ case user.FieldBalanceNotifyThreshold:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddBalanceNotifyThreshold(v)
+ return nil
}
return fmt.Errorf("unknown User numeric field %s", name)
}
@@ -29790,6 +29990,9 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldTotpEnabledAt) {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.FieldCleared(user.FieldBalanceNotifyThreshold) {
+ fields = append(fields, user.FieldBalanceNotifyThreshold)
+ }
return fields
}
@@ -29813,6 +30016,9 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldTotpEnabledAt:
m.ClearTotpEnabledAt()
return nil
+ case user.FieldBalanceNotifyThreshold:
+ m.ClearBalanceNotifyThreshold()
+ return nil
}
return fmt.Errorf("unknown User nullable field %s", name)
}
@@ -29863,6 +30069,15 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt()
return nil
+ case user.FieldBalanceNotifyEnabled:
+ m.ResetBalanceNotifyEnabled()
+ return nil
+ case user.FieldBalanceNotifyThreshold:
+ m.ResetBalanceNotifyThreshold()
+ return nil
+ case user.FieldBalanceNotifyExtraEmails:
+ m.ResetBalanceNotifyExtraEmails()
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index 821b7d66..a288f5d9 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -1293,6 +1293,14 @@ func init() {
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
+ // userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field.
+ userDescBalanceNotifyEnabled := userFields[11].Descriptor()
+ // user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field.
+ user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool)
+ // userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field.
+ userDescBalanceNotifyExtraEmails := userFields[13].Descriptor()
+ // user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field.
+ user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index af143d38..bdaa4509 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -72,6 +72,17 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
+
+ // Low-balance notification settings
+ field.Bool("balance_notify_enabled").
+ Default(true),
+ field.Float("balance_notify_threshold").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
+ Optional().
+ Nillable(),
+ field.String("balance_notify_extra_emails").
+ SchemaType(map[string]string{dialect.Postgres: "text"}).
+ Default("[]"),
}
}
diff --git a/backend/ent/user.go b/backend/ent/user.go
index a0eef2ba..fc4ddb8f 100644
--- a/backend/ent/user.go
+++ b/backend/ent/user.go
@@ -45,6 +45,12 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
+ // BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
+ BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
+ // BalanceNotifyThreshold holds the value of the "balance_notify_threshold" field.
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
+ // BalanceNotifyExtraEmails holds the value of the "balance_notify_extra_emails" field.
+ BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -184,13 +190,13 @@ func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
- case user.FieldTotpEnabled:
+ case user.FieldTotpEnabled, user.FieldBalanceNotifyEnabled:
values[i] = new(sql.NullBool)
- case user.FieldBalance:
+ case user.FieldBalance, user.FieldBalanceNotifyThreshold:
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
- case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
+ case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyExtraEmails:
values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
values[i] = new(sql.NullTime)
@@ -302,6 +308,25 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
+ case user.FieldBalanceNotifyEnabled:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i])
+ } else if value.Valid {
+ _m.BalanceNotifyEnabled = value.Bool
+ }
+ case user.FieldBalanceNotifyThreshold:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field balance_notify_threshold", values[i])
+ } else if value.Valid {
+ _m.BalanceNotifyThreshold = new(float64)
+ *_m.BalanceNotifyThreshold = value.Float64
+ }
+ case user.FieldBalanceNotifyExtraEmails:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field balance_notify_extra_emails", values[i])
+ } else if value.Valid {
+ _m.BalanceNotifyExtraEmails = value.String
+ }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -440,6 +465,17 @@ func (_m *User) String() string {
builder.WriteString("totp_enabled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
+ builder.WriteString(", ")
+ builder.WriteString("balance_notify_enabled=")
+ builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
+ builder.WriteString(", ")
+ if v := _m.BalanceNotifyThreshold; v != nil {
+ builder.WriteString("balance_notify_threshold=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("balance_notify_extra_emails=")
+ builder.WriteString(_m.BalanceNotifyExtraEmails)
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go
index 338518a8..aff37013 100644
--- a/backend/ent/user/user.go
+++ b/backend/ent/user/user.go
@@ -43,6 +43,12 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
+ // FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
+ FieldBalanceNotifyEnabled = "balance_notify_enabled"
+ // FieldBalanceNotifyThreshold holds the string denoting the balance_notify_threshold field in the database.
+ FieldBalanceNotifyThreshold = "balance_notify_threshold"
+ // FieldBalanceNotifyExtraEmails holds the string denoting the balance_notify_extra_emails field in the database.
+ FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -161,6 +167,9 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
+ FieldBalanceNotifyEnabled,
+ FieldBalanceNotifyThreshold,
+ FieldBalanceNotifyExtraEmails,
}
var (
@@ -217,6 +226,10 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
+ // DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
+ DefaultBalanceNotifyEnabled bool
+ // DefaultBalanceNotifyExtraEmails holds the default value on creation for the "balance_notify_extra_emails" field.
+ DefaultBalanceNotifyExtraEmails string
)
// OrderOption defines the ordering options for the User queries.
@@ -297,6 +310,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
+// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field.
+func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
+}
+
+// ByBalanceNotifyThreshold orders the results by the balance_notify_threshold field.
+func ByBalanceNotifyThreshold(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBalanceNotifyThreshold, opts...).ToFunc()
+}
+
+// ByBalanceNotifyExtraEmails orders the results by the balance_notify_extra_emails field.
+func ByBalanceNotifyExtraEmails(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBalanceNotifyExtraEmails, opts...).ToFunc()
+}
+
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go
index b1d1000f..11a0318f 100644
--- a/backend/ent/user/where.go
+++ b/backend/ent/user/where.go
@@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
+// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ.
+func BalanceNotifyEnabled(v bool) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
+}
+
+// BalanceNotifyThreshold applies equality check predicate on the "balance_notify_threshold" field. It's identical to BalanceNotifyThresholdEQ.
+func BalanceNotifyThreshold(v float64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyExtraEmails applies equality check predicate on the "balance_notify_extra_emails" field. It's identical to BalanceNotifyExtraEmailsEQ.
+func BalanceNotifyExtraEmails(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -860,6 +875,131 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
+// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field.
+func BalanceNotifyEnabledEQ(v bool) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
+}
+
+// BalanceNotifyEnabledNEQ applies the NEQ predicate on the "balance_notify_enabled" field.
+func BalanceNotifyEnabledNEQ(v bool) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldBalanceNotifyEnabled, v))
+}
+
+// BalanceNotifyThresholdEQ applies the EQ predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdEQ(v float64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdNEQ applies the NEQ predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdNEQ(v float64) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdIn applies the In predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdIn(vs ...float64) predicate.User {
+ return predicate.User(sql.FieldIn(FieldBalanceNotifyThreshold, vs...))
+}
+
+// BalanceNotifyThresholdNotIn applies the NotIn predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdNotIn(vs ...float64) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThreshold, vs...))
+}
+
+// BalanceNotifyThresholdGT applies the GT predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdGT(v float64) predicate.User {
+ return predicate.User(sql.FieldGT(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdGTE applies the GTE predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdGTE(v float64) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdLT applies the LT predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdLT(v float64) predicate.User {
+ return predicate.User(sql.FieldLT(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdLTE applies the LTE predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdLTE(v float64) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdIsNil applies the IsNil predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldBalanceNotifyThreshold))
+}
+
+// BalanceNotifyThresholdNotNil applies the NotNil predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldBalanceNotifyThreshold))
+}
+
+// BalanceNotifyExtraEmailsEQ applies the EQ predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsEQ(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsNEQ applies the NEQ predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsNEQ(v string) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsIn applies the In predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldIn(FieldBalanceNotifyExtraEmails, vs...))
+}
+
+// BalanceNotifyExtraEmailsNotIn applies the NotIn predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsNotIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldBalanceNotifyExtraEmails, vs...))
+}
+
+// BalanceNotifyExtraEmailsGT applies the GT predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsGT(v string) predicate.User {
+ return predicate.User(sql.FieldGT(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsGTE applies the GTE predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsGTE(v string) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsLT applies the LT predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsLT(v string) predicate.User {
+ return predicate.User(sql.FieldLT(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsLTE applies the LTE predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsLTE(v string) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsContains applies the Contains predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsContains(v string) predicate.User {
+ return predicate.User(sql.FieldContains(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsHasPrefix applies the HasPrefix predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsHasPrefix(v string) predicate.User {
+ return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsHasSuffix applies the HasSuffix predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsHasSuffix(v string) predicate.User {
+ return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsEqualFold applies the EqualFold predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsEqualFold(v string) predicate.User {
+ return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsContainsFold applies the ContainsFold predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsContainsFold(v string) predicate.User {
+ return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyExtraEmails, v))
+}
+
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {
diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go
index 7f1c5df1..955fde72 100644
--- a/backend/ent/user_create.go
+++ b/backend/ent/user_create.go
@@ -211,6 +211,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate {
+ _c.mutation.SetBalanceNotifyEnabled(v)
+ return _c
+}
+
+// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
+func (_c *UserCreate) SetNillableBalanceNotifyEnabled(v *bool) *UserCreate {
+ if v != nil {
+ _c.SetBalanceNotifyEnabled(*v)
+ }
+ return _c
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (_c *UserCreate) SetBalanceNotifyThreshold(v float64) *UserCreate {
+ _c.mutation.SetBalanceNotifyThreshold(v)
+ return _c
+}
+
+// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
+func (_c *UserCreate) SetNillableBalanceNotifyThreshold(v *float64) *UserCreate {
+ if v != nil {
+ _c.SetBalanceNotifyThreshold(*v)
+ }
+ return _c
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (_c *UserCreate) SetBalanceNotifyExtraEmails(v string) *UserCreate {
+ _c.mutation.SetBalanceNotifyExtraEmails(v)
+ return _c
+}
+
+// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
+func (_c *UserCreate) SetNillableBalanceNotifyExtraEmails(v *string) *UserCreate {
+ if v != nil {
+ _c.SetBalanceNotifyExtraEmails(*v)
+ }
+ return _c
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -440,6 +482,14 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
+ if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
+ v := user.DefaultBalanceNotifyEnabled
+ _c.mutation.SetBalanceNotifyEnabled(v)
+ }
+ if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
+ v := user.DefaultBalanceNotifyExtraEmails
+ _c.mutation.SetBalanceNotifyExtraEmails(v)
+ }
return nil
}
@@ -503,6 +553,12 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
+ if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
+ return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
+ }
+ if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
+ return &ValidationError{Name: "balance_notify_extra_emails", err: errors.New(`ent: missing required field "User.balance_notify_extra_emails"`)}
+ }
return nil
}
@@ -586,6 +642,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
+ if value, ok := _c.mutation.BalanceNotifyEnabled(); ok {
+ _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
+ _node.BalanceNotifyEnabled = value
+ }
+ if value, ok := _c.mutation.BalanceNotifyThreshold(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ _node.BalanceNotifyThreshold = &value
+ }
+ if value, ok := _c.mutation.BalanceNotifyExtraEmails(); ok {
+ _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
+ _node.BalanceNotifyExtraEmails = value
+ }
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -988,6 +1056,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert {
+ u.Set(user.FieldBalanceNotifyEnabled, v)
+ return u
+}
+
+// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
+func (u *UserUpsert) UpdateBalanceNotifyEnabled() *UserUpsert {
+ u.SetExcluded(user.FieldBalanceNotifyEnabled)
+ return u
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (u *UserUpsert) SetBalanceNotifyThreshold(v float64) *UserUpsert {
+ u.Set(user.FieldBalanceNotifyThreshold, v)
+ return u
+}
+
+// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
+func (u *UserUpsert) UpdateBalanceNotifyThreshold() *UserUpsert {
+ u.SetExcluded(user.FieldBalanceNotifyThreshold)
+ return u
+}
+
+// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
+func (u *UserUpsert) AddBalanceNotifyThreshold(v float64) *UserUpsert {
+ u.Add(user.FieldBalanceNotifyThreshold, v)
+ return u
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (u *UserUpsert) ClearBalanceNotifyThreshold() *UserUpsert {
+ u.SetNull(user.FieldBalanceNotifyThreshold)
+ return u
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (u *UserUpsert) SetBalanceNotifyExtraEmails(v string) *UserUpsert {
+ u.Set(user.FieldBalanceNotifyExtraEmails, v)
+ return u
+}
+
+// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
+func (u *UserUpsert) UpdateBalanceNotifyExtraEmails() *UserUpsert {
+ u.SetExcluded(user.FieldBalanceNotifyExtraEmails)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1250,6 +1366,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyEnabled(v)
+ })
+}
+
+// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateBalanceNotifyEnabled() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyEnabled()
+ })
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (u *UserUpsertOne) SetBalanceNotifyThreshold(v float64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyThreshold(v)
+ })
+}
+
+// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
+func (u *UserUpsertOne) AddBalanceNotifyThreshold(v float64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.AddBalanceNotifyThreshold(v)
+ })
+}
+
+// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateBalanceNotifyThreshold() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyThreshold()
+ })
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (u *UserUpsertOne) ClearBalanceNotifyThreshold() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearBalanceNotifyThreshold()
+ })
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (u *UserUpsertOne) SetBalanceNotifyExtraEmails(v string) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyExtraEmails(v)
+ })
+}
+
+// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateBalanceNotifyExtraEmails() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyExtraEmails()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1678,6 +1850,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyEnabled(v)
+ })
+}
+
+// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateBalanceNotifyEnabled() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyEnabled()
+ })
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (u *UserUpsertBulk) SetBalanceNotifyThreshold(v float64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyThreshold(v)
+ })
+}
+
+// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
+func (u *UserUpsertBulk) AddBalanceNotifyThreshold(v float64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.AddBalanceNotifyThreshold(v)
+ })
+}
+
+// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateBalanceNotifyThreshold() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyThreshold()
+ })
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (u *UserUpsertBulk) ClearBalanceNotifyThreshold() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearBalanceNotifyThreshold()
+ })
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (u *UserUpsertBulk) SetBalanceNotifyExtraEmails(v string) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyExtraEmails(v)
+ })
+}
+
+// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateBalanceNotifyExtraEmails() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyExtraEmails()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go
index 8107c980..823df0b6 100644
--- a/backend/ent/user_update.go
+++ b/backend/ent/user_update.go
@@ -243,6 +243,61 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate {
+ _u.mutation.SetBalanceNotifyEnabled(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdate {
+ if v != nil {
+ _u.SetBalanceNotifyEnabled(*v)
+ }
+ return _u
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (_u *UserUpdate) SetBalanceNotifyThreshold(v float64) *UserUpdate {
+ _u.mutation.ResetBalanceNotifyThreshold()
+ _u.mutation.SetBalanceNotifyThreshold(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdate {
+ if v != nil {
+ _u.SetBalanceNotifyThreshold(*v)
+ }
+ return _u
+}
+
+// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field.
+func (_u *UserUpdate) AddBalanceNotifyThreshold(v float64) *UserUpdate {
+ _u.mutation.AddBalanceNotifyThreshold(v)
+ return _u
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (_u *UserUpdate) ClearBalanceNotifyThreshold() *UserUpdate {
+ _u.mutation.ClearBalanceNotifyThreshold()
+ return _u
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (_u *UserUpdate) SetBalanceNotifyExtraEmails(v string) *UserUpdate {
+ _u.mutation.SetBalanceNotifyExtraEmails(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdate {
+ if v != nil {
+ _u.SetBalanceNotifyExtraEmails(*v)
+ }
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -746,6 +801,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
+ _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok {
+ _spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ }
+ if _u.mutation.BalanceNotifyThresholdCleared() {
+ _spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64)
+ }
+ if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
+ _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1434,6 +1504,61 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne {
+ _u.mutation.SetBalanceNotifyEnabled(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdateOne {
+ if v != nil {
+ _u.SetBalanceNotifyEnabled(*v)
+ }
+ return _u
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (_u *UserUpdateOne) SetBalanceNotifyThreshold(v float64) *UserUpdateOne {
+ _u.mutation.ResetBalanceNotifyThreshold()
+ _u.mutation.SetBalanceNotifyThreshold(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdateOne {
+ if v != nil {
+ _u.SetBalanceNotifyThreshold(*v)
+ }
+ return _u
+}
+
+// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field.
+func (_u *UserUpdateOne) AddBalanceNotifyThreshold(v float64) *UserUpdateOne {
+ _u.mutation.AddBalanceNotifyThreshold(v)
+ return _u
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (_u *UserUpdateOne) ClearBalanceNotifyThreshold() *UserUpdateOne {
+ _u.mutation.ClearBalanceNotifyThreshold()
+ return _u
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (_u *UserUpdateOne) SetBalanceNotifyExtraEmails(v string) *UserUpdateOne {
+ _u.mutation.SetBalanceNotifyExtraEmails(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdateOne {
+ if v != nil {
+ _u.SetBalanceNotifyExtraEmails(*v)
+ }
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1967,6 +2092,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
+ _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok {
+ _spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ }
+ if _u.mutation.BalanceNotifyThresholdCleared() {
+ _spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64)
+ }
+ if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
+ _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
diff --git a/backend/go.sum b/backend/go.sum
index e4496f2c..9312af63 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -183,6 +183,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
+github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
+github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -218,6 +220,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
+github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -251,6 +255,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
+github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
+github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -280,6 +286,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
+github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@@ -312,6 +320,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
+github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
+github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 031b819a..459eade9 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -175,7 +175,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
- WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
+ BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
+ BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
+ AccountQuotaNotifyEmails: settings.AccountQuotaNotifyEmails,
PaymentEnabled: paymentCfg.Enabled,
PaymentMinAmount: paymentCfg.MinAmount,
PaymentMaxAmount: paymentCfg.MaxAmount,
@@ -305,6 +307,11 @@ type UpdateSettingsRequest struct {
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"`
+ // Balance low notification
+ BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
+ BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
+ AccountQuotaNotifyEmails *[]string `json:"account_quota_notify_emails"`
+
// Payment configuration (integrated into settings, full replace)
PaymentEnabled *bool `json:"payment_enabled"`
PaymentMinAmount *float64 `json:"payment_min_amount"`
@@ -882,6 +889,24 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableCCHSigning
}(),
+ BalanceLowNotifyEnabled: func() bool {
+ if req.BalanceLowNotifyEnabled != nil {
+ return *req.BalanceLowNotifyEnabled
+ }
+ return previousSettings.BalanceLowNotifyEnabled
+ }(),
+ BalanceLowNotifyThreshold: func() float64 {
+ if req.BalanceLowNotifyThreshold != nil {
+ return *req.BalanceLowNotifyThreshold
+ }
+ return previousSettings.BalanceLowNotifyThreshold
+ }(),
+ AccountQuotaNotifyEmails: func() []string {
+ if req.AccountQuotaNotifyEmails != nil {
+ return *req.AccountQuotaNotifyEmails
+ }
+ return previousSettings.AccountQuotaNotifyEmails
+ }(),
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
@@ -1028,6 +1053,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning,
+ BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
+ BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
+ AccountQuotaNotifyEmails: updatedSettings.AccountQuotaNotifyEmails,
PaymentEnabled: updatedPaymentCfg.Enabled,
PaymentMinAmount: updatedPaymentCfg.MinAmount,
PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
@@ -1848,37 +1876,3 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
})
}
-
-// GetWebSearchEmulationConfig 获取 Web Search 模拟配置
-// GET /api/v1/admin/settings/web-search-emulation
-func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) {
- cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, service.SanitizeWebSearchConfig(cfg))
-}
-
-// UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置
-// PUT /api/v1/admin/settings/web-search-emulation
-func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) {
- var cfg service.WebSearchEmulationConfig
- if err := c.ShouldBindJSON(&cfg); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
-
- if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- // Re-read (with sanitized api keys) to return current state
- updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, service.SanitizeWebSearchConfig(updated))
-}
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index 478600eb..a465c7fb 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -13,16 +13,19 @@ func UserFromServiceShallow(u *service.User) *User {
return nil
}
return &User{
- ID: u.ID,
- Email: u.Email,
- Username: u.Username,
- Role: u.Role,
- Balance: u.Balance,
- Concurrency: u.Concurrency,
- Status: u.Status,
- AllowedGroups: u.AllowedGroups,
- CreatedAt: u.CreatedAt,
- UpdatedAt: u.UpdatedAt,
+ ID: u.ID,
+ Email: u.Email,
+ Username: u.Username,
+ Role: u.Role,
+ Balance: u.Balance,
+ Concurrency: u.Concurrency,
+ Status: u.Status,
+ AllowedGroups: u.AllowedGroups,
+ CreatedAt: u.CreatedAt,
+ UpdatedAt: u.UpdatedAt,
+ BalanceNotifyEnabled: u.BalanceNotifyEnabled,
+ BalanceNotifyThreshold: u.BalanceNotifyThreshold,
+ BalanceNotifyExtraEmails: u.BalanceNotifyExtraEmails,
}
}
@@ -322,6 +325,26 @@ func AccountFromServiceShallow(a *service.Account) *Account {
out.QuotaWeeklyResetAt = &v
}
}
+
+ // 配额通知配置
+ if enabled := a.GetQuotaNotifyDailyEnabled(); enabled {
+ out.QuotaNotifyDailyEnabled = &enabled
+ }
+ if threshold := a.GetQuotaNotifyDailyThreshold(); threshold > 0 {
+ out.QuotaNotifyDailyThreshold = &threshold
+ }
+ if enabled := a.GetQuotaNotifyWeeklyEnabled(); enabled {
+ out.QuotaNotifyWeeklyEnabled = &enabled
+ }
+ if threshold := a.GetQuotaNotifyWeeklyThreshold(); threshold > 0 {
+ out.QuotaNotifyWeeklyThreshold = &threshold
+ }
+ if enabled := a.GetQuotaNotifyTotalEnabled(); enabled {
+ out.QuotaNotifyTotalEnabled = &enabled
+ }
+ if threshold := a.GetQuotaNotifyTotalThreshold(); threshold > 0 {
+ out.QuotaNotifyTotalThreshold = &threshold
+ }
}
return out
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 0433d692..e29f72da 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -148,6 +148,11 @@ type SystemSettings struct {
PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"`
PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
+
+ // Balance low notification
+ BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
+ BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
+ AccountQuotaNotifyEmails []string `json:"account_quota_notify_emails"`
}
type DefaultSubscriptionSetting struct {
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index e026ca65..18522868 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -18,6 +18,11 @@ type User struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
+ // 余额不足通知
+ BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
+ BalanceNotifyExtraEmails []string `json:"balance_notify_extra_emails"`
+
APIKeys []APIKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
@@ -218,6 +223,14 @@ type Account struct {
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
+ // 配额通知配置
+ QuotaNotifyDailyEnabled *bool `json:"quota_notify_daily_enabled,omitempty"`
+ QuotaNotifyDailyThreshold *float64 `json:"quota_notify_daily_threshold,omitempty"`
+ QuotaNotifyWeeklyEnabled *bool `json:"quota_notify_weekly_enabled,omitempty"`
+ QuotaNotifyWeeklyThreshold *float64 `json:"quota_notify_weekly_threshold,omitempty"`
+ QuotaNotifyTotalEnabled *bool `json:"quota_notify_total_enabled,omitempty"`
+ QuotaNotifyTotalThreshold *float64 `json:"quota_notify_total_threshold,omitempty"`
+
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 35862f1c..42463a7a 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -11,13 +11,17 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
- userService *service.UserService
+ userService *service.UserService
+ emailService *service.EmailService
+ emailCache service.EmailCache
}
// NewUserHandler creates a new UserHandler
-func NewUserHandler(userService *service.UserService) *UserHandler {
+func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler {
return &UserHandler{
- userService: userService,
+ userService: userService,
+ emailService: emailService,
+ emailCache: emailCache,
}
}
@@ -29,7 +33,9 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
- Username *string `json:"username"`
+ Username *string `json:"username"`
+ BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
// GetProfile handles getting user profile
@@ -94,7 +100,9 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
}
svcReq := service.UpdateProfileRequest{
- Username: req.Username,
+ Username: req.Username,
+ BalanceNotifyEnabled: req.BalanceNotifyEnabled,
+ BalanceNotifyThreshold: req.BalanceNotifyThreshold,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
@@ -104,3 +112,98 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
response.Success(c, dto.UserFromService(updatedUser))
}
+
+// SendNotifyEmailCodeRequest represents the request to send notify email verification code
+type SendNotifyEmailCodeRequest struct {
+ Email string `json:"email" binding:"required,email"`
+}
+
+// SendNotifyEmailCode sends verification code to extra notification email
+// POST /api/v1/user/notify-email/send-code
+func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req SendNotifyEmailCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "Verification code sent successfully"})
+}
+
+// VerifyNotifyEmailRequest represents the request to verify and add notify email
+type VerifyNotifyEmailRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ Code string `json:"code" binding:"required,len=6"`
+}
+
+// VerifyNotifyEmail verifies code and adds email to notification list
+// POST /api/v1/user/notify-email/verify
+func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req VerifyNotifyEmailRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ err := h.userService.VerifyAndAddNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Code, h.emailCache)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Return updated user
+ updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, dto.UserFromService(updatedUser))
+}
+
+// RemoveNotifyEmailRequest represents the request to remove a notify email
+type RemoveNotifyEmailRequest struct {
+ Email string `json:"email" binding:"required,email"`
+}
+
+// RemoveNotifyEmail removes email from notification list
+// DELETE /api/v1/user/notify-email
+func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req RemoveNotifyEmailRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ err := h.userService.RemoveNotifyEmail(c.Request.Context(), subject.UserID, req.Email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "Email removed successfully"})
+}
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index 7fd98855..752a5937 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
+ "encoding/json"
"fmt"
"strings"
"time"
@@ -639,22 +640,32 @@ func userEntityToService(u *dbent.User) *service.User {
if u == nil {
return nil
}
- return &service.User{
- ID: u.ID,
- Email: u.Email,
- Username: u.Username,
- Notes: u.Notes,
- PasswordHash: u.PasswordHash,
- Role: u.Role,
- Balance: u.Balance,
- Concurrency: u.Concurrency,
- Status: u.Status,
- TotpSecretEncrypted: u.TotpSecretEncrypted,
- TotpEnabled: u.TotpEnabled,
- TotpEnabledAt: u.TotpEnabledAt,
- CreatedAt: u.CreatedAt,
- UpdatedAt: u.UpdatedAt,
+ out := &service.User{
+ ID: u.ID,
+ Email: u.Email,
+ Username: u.Username,
+ Notes: u.Notes,
+ PasswordHash: u.PasswordHash,
+ Role: u.Role,
+ Balance: u.Balance,
+ Concurrency: u.Concurrency,
+ Status: u.Status,
+ TotpSecretEncrypted: u.TotpSecretEncrypted,
+ TotpEnabled: u.TotpEnabled,
+ TotpEnabledAt: u.TotpEnabledAt,
+ BalanceNotifyEnabled: u.BalanceNotifyEnabled,
+ BalanceNotifyThreshold: u.BalanceNotifyThreshold,
+ CreatedAt: u.CreatedAt,
+ UpdatedAt: u.UpdatedAt,
}
+ // Parse extra emails JSON array
+ if u.BalanceNotifyExtraEmails != "" && u.BalanceNotifyExtraEmails != "[]" {
+ var emails []string
+ if err := json.Unmarshal([]byte(u.BalanceNotifyExtraEmails), &emails); err == nil {
+ out.BalanceNotifyExtraEmails = emails
+ }
+ }
+ return out
}
func groupEntityToService(g *dbent.Group) *service.Group {
diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go
index 8f2b8eca..63552ab0 100644
--- a/backend/internal/repository/email_cache.go
+++ b/backend/internal/repository/email_cache.go
@@ -11,6 +11,7 @@ import (
const (
verifyCodeKeyPrefix = "verify_code:"
+ notifyVerifyKeyPrefix = "notify_verify:"
passwordResetKeyPrefix = "password_reset:"
passwordResetSentAtKeyPrefix = "password_reset_sent:"
)
@@ -20,6 +21,11 @@ func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
}
+// notifyVerifyKey generates the Redis key for notify email verification code.
+func notifyVerifyKey(email string) string {
+ return notifyVerifyKeyPrefix + email
+}
+
// passwordResetKey generates the Redis key for password reset token.
func passwordResetKey(email string) string {
return passwordResetKeyPrefix + email
@@ -106,3 +112,32 @@ func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email st
key := passwordResetSentAtKey(email)
return c.rdb.Set(ctx, key, "1", ttl).Err()
}
+
+// Notify email verification code methods
+
+func (c *emailCache) GetNotifyVerifyCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
+ key := notifyVerifyKey(email)
+ val, err := c.rdb.Get(ctx, key).Result()
+ if err != nil {
+ return nil, err
+ }
+ var data service.VerificationCodeData
+ if err := json.Unmarshal([]byte(val), &data); err != nil {
+ return nil, err
+ }
+ return &data, nil
+}
+
+func (c *emailCache) SetNotifyVerifyCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
+ key := notifyVerifyKey(email)
+ val, err := json.Marshal(data)
+ if err != nil {
+ return err
+ }
+ return c.rdb.Set(ctx, key, val, ttl).Err()
+}
+
+func (c *emailCache) DeleteNotifyVerifyCode(ctx context.Context, email string) error {
+ key := notifyVerifyKey(email)
+ return c.rdb.Del(ctx, key).Err()
+}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index d5a13607..2c544857 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
+ "encoding/json"
"errors"
"fmt"
"sort"
@@ -137,7 +138,7 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
txClient = r.client
}
- updated, err := txClient.User.UpdateOneID(userIn.ID).
+ updateOp := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetNotes(userIn.Notes).
@@ -146,7 +147,13 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
- Save(ctx)
+ SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled).
+ SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
+ SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails))
+ if userIn.BalanceNotifyThreshold == nil {
+ updateOp = updateOp.ClearBalanceNotifyThreshold()
+ }
+ updated, err := updateOp.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
}
@@ -549,6 +556,18 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst.UpdatedAt = src.UpdatedAt
}
+// marshalExtraEmails serializes a string slice to JSON for storage.
+func marshalExtraEmails(emails []string) string {
+ if len(emails) == 0 {
+ return "[]"
+ }
+ data, err := json.Marshal(emails)
+ if err != nil {
+ return "[]"
+ }
+ return string(data)
+}
+
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
client := clientFromContext(ctx, r.client)
diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go
index c3b82742..088565fa 100644
--- a/backend/internal/server/routes/user.go
+++ b/backend/internal/server/routes/user.go
@@ -26,6 +26,14 @@ func RegisterUserRoutes(
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
+ // 通知邮箱管理
+ notifyEmail := user.Group("/notify-email")
+ {
+ notifyEmail.POST("/send-code", h.User.SendNotifyEmailCode)
+ notifyEmail.POST("/verify", h.User.VerifyNotifyEmail)
+ notifyEmail.DELETE("", h.User.RemoveNotifyEmail)
+ }
+
// TOTP 双因素认证
totp := user.Group("/totp")
{
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 582b136c..0b225dac 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -1406,6 +1406,19 @@ func (a *Account) getExtraTime(key string) time.Time {
return time.Time{}
}
+// getExtraBool 从 Extra 中读取指定 key 的 bool 值
+func (a *Account) getExtraBool(key string) bool {
+ if a.Extra == nil {
+ return false
+ }
+ if v, ok := a.Extra[key]; ok {
+ if b, ok := v.(bool); ok {
+ return b
+ }
+ }
+ return false
+}
+
// getExtraString 从 Extra 中读取指定 key 的字符串值
func (a *Account) getExtraString(key string) string {
if a.Extra == nil {
@@ -1475,6 +1488,32 @@ func (a *Account) GetQuotaResetTimezone() string {
return "UTC"
}
+// --- Quota Notification Getters ---
+
+func (a *Account) GetQuotaNotifyDailyEnabled() bool {
+ return a.getExtraBool("quota_notify_daily_enabled")
+}
+
+func (a *Account) GetQuotaNotifyDailyThreshold() float64 {
+ return a.getExtraFloat64("quota_notify_daily_threshold")
+}
+
+func (a *Account) GetQuotaNotifyWeeklyEnabled() bool {
+ return a.getExtraBool("quota_notify_weekly_enabled")
+}
+
+func (a *Account) GetQuotaNotifyWeeklyThreshold() float64 {
+ return a.getExtraFloat64("quota_notify_weekly_threshold")
+}
+
+func (a *Account) GetQuotaNotifyTotalEnabled() bool {
+ return a.getExtraBool("quota_notify_total_enabled")
+}
+
+func (a *Account) GetQuotaNotifyTotalThreshold() float64 {
+ return a.getExtraFloat64("quota_notify_total_threshold")
+}
+
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
t := after.In(tz)
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index 7b50e90d..0999b4f0 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -87,6 +87,18 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return nil
}
+func (s *emailCacheStub) GetNotifyVerifyCode(ctx context.Context, email string) (*VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *emailCacheStub) SetNotifyVerifyCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error {
+ return nil
+}
+
+func (s *emailCacheStub) DeleteNotifyVerifyCode(ctx context.Context, email string) error {
+ return nil
+}
+
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
return nil, nil
}
diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go
new file mode 100644
index 00000000..7cd61a0a
--- /dev/null
+++ b/backend/internal/service/balance_notify_service.go
@@ -0,0 +1,328 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "strconv"
+ "strings"
+ "time"
+)
+
+const (
+ emailSendTimeout = 30 * time.Second
+
+ // Quota dimension labels
+ quotaDimDaily = "daily"
+ quotaDimWeekly = "weekly"
+ quotaDimTotal = "total"
+)
+
+// quotaDimLabels maps dimension names to display labels.
+var quotaDimLabels = map[string]string{
+ quotaDimDaily: "日限额 / Daily",
+ quotaDimWeekly: "周限额 / Weekly",
+ quotaDimTotal: "总限额 / Total",
+}
+
+// BalanceNotifyService handles balance and quota threshold notifications.
+type BalanceNotifyService struct {
+ emailService *EmailService
+ settingRepo SettingRepository
+}
+
+// NewBalanceNotifyService creates a new BalanceNotifyService.
+func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository) *BalanceNotifyService {
+ return &BalanceNotifyService{
+ emailService: emailService,
+ settingRepo: settingRepo,
+ }
+}
+
+// CheckBalanceAfterDeduction checks if balance crossed below threshold after deduction.
+// oldBalance is the balance before deduction, cost is the amount deducted.
+// Notification is sent only on first crossing: oldBalance >= threshold && newBalance < threshold.
+func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, user *User, oldBalance, cost float64) {
+ if user == nil || s.emailService == nil || s.settingRepo == nil {
+ return
+ }
+
+ // Check user-level switch
+ if !user.BalanceNotifyEnabled {
+ return
+ }
+
+ // Check global switch
+ globalEnabled, threshold := s.getBalanceNotifyConfig(ctx)
+ if !globalEnabled {
+ return
+ }
+
+ // User custom threshold overrides system default
+ if user.BalanceNotifyThreshold != nil {
+ threshold = *user.BalanceNotifyThreshold
+ }
+
+ if threshold <= 0 {
+ return
+ }
+
+ newBalance := oldBalance - cost
+
+ // Only notify on first crossing
+ if oldBalance >= threshold && newBalance < threshold {
+ siteName := s.getSiteName(ctx)
+ recipients := s.collectBalanceNotifyRecipients(user)
+ go func() {
+ defer func() {
+ if r := recover(); r != nil {
+ slog.Error("panic in balance notification", "recover", r)
+ }
+ }()
+ s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName)
+ }()
+ }
+}
+
+// CheckAccountQuotaAfterIncrement checks if any quota dimension crossed above its notify threshold.
+// The account's Extra fields contain pre-increment usage values.
+func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Context, account *Account, cost float64) {
+ if account == nil || s.emailService == nil || s.settingRepo == nil || cost <= 0 {
+ return
+ }
+
+ adminEmails := s.getAccountQuotaNotifyEmails(ctx)
+ if len(adminEmails) == 0 {
+ return
+ }
+
+ siteName := s.getSiteName(ctx)
+
+ // Check each dimension
+ type quotaDim struct {
+ name string
+ enabled bool
+ threshold float64
+ oldUsed float64
+ limit float64
+ }
+
+ dims := []quotaDim{
+ {
+ name: quotaDimDaily,
+ enabled: account.GetQuotaNotifyDailyEnabled(),
+ threshold: account.GetQuotaNotifyDailyThreshold(),
+ oldUsed: account.GetQuotaDailyUsed(),
+ limit: account.GetQuotaDailyLimit(),
+ },
+ {
+ name: quotaDimWeekly,
+ enabled: account.GetQuotaNotifyWeeklyEnabled(),
+ threshold: account.GetQuotaNotifyWeeklyThreshold(),
+ oldUsed: account.GetQuotaWeeklyUsed(),
+ limit: account.GetQuotaWeeklyLimit(),
+ },
+ {
+ name: quotaDimTotal,
+ enabled: account.GetQuotaNotifyTotalEnabled(),
+ threshold: account.GetQuotaNotifyTotalThreshold(),
+ oldUsed: account.GetQuotaUsed(),
+ limit: account.GetQuotaLimit(),
+ },
+ }
+
+ for _, dim := range dims {
+ if !dim.enabled || dim.threshold <= 0 {
+ continue
+ }
+ newUsed := dim.oldUsed + cost
+ // Only notify on first crossing
+ if dim.oldUsed < dim.threshold && newUsed >= dim.threshold {
+ dimCopy := dim // capture loop variable
+ go func() {
+ defer func() {
+ if r := recover(); r != nil {
+ slog.Error("panic in quota notification", "recover", r)
+ }
+ }()
+ s.sendQuotaAlertEmails(adminEmails, account.Name, dimCopy.name, newUsed, dimCopy.limit, dimCopy.threshold, siteName)
+ }()
+ }
+ }
+}
+
+// getBalanceNotifyConfig reads global balance notification settings.
+func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, threshold float64) {
+ keys := []string{SettingKeyBalanceLowNotifyEnabled, SettingKeyBalanceLowNotifyThreshold}
+ settings, err := s.settingRepo.GetMultiple(ctx, keys)
+ if err != nil {
+ return false, 0
+ }
+ enabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
+ if v := settings[SettingKeyBalanceLowNotifyThreshold]; v != "" {
+ if f, err := strconv.ParseFloat(v, 64); err == nil {
+ threshold = f
+ }
+ }
+ return
+}
+
+// getAccountQuotaNotifyEmails reads admin notification emails from settings.
+func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) []string {
+ raw, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEmails)
+ if err != nil || strings.TrimSpace(raw) == "" || raw == "[]" {
+ return nil
+ }
+ return parseJSONStringArray(raw)
+}
+
+// getSiteName reads site name from settings with fallback.
+func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
+ name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
+ if err != nil || name == "" {
+ return "Sub2API"
+ }
+ return name
+}
+
+// collectBalanceNotifyRecipients collects all email recipients for balance notifications.
+func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string {
+ recipients := []string{user.Email}
+ for _, extra := range user.BalanceNotifyExtraEmails {
+ email := strings.TrimSpace(extra)
+ if email != "" && email != user.Email {
+ recipients = append(recipients, email)
+ }
+ }
+ return recipients
+}
+
+// sendEmails sends the same subject/body to each recipient under one shared timeout; individual send failures are logged and do not stop the remaining sends.
+func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body string, logAttrs ...any) {
+ ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
+ defer cancel()
+ for _, to := range recipients {
+ if err := s.emailService.SendEmail(ctx, to, subject, body); err != nil {
+ attrs := append([]any{"to", to, "error", err}, logAttrs...)
+ slog.Error("failed to send notification", attrs...)
+ }
+ }
+}
+
+// sendBalanceLowEmails sends balance low notification to all recipients.
+func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userName, userEmail string, balance, threshold float64, siteName string) {
+ displayName := userName
+ if displayName == "" {
+ displayName = userEmail
+ }
+ subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", siteName)
+ body := s.buildBalanceLowEmailBody(displayName, balance, threshold, siteName)
+ s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
+}
+
+// sendQuotaAlertEmails sends quota alert notification to admin emails.
+func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accountName, dimension string, used, limit, threshold float64, siteName string) {
+ dimLabel := quotaDimLabels[dimension]
+ if dimLabel == "" {
+ dimLabel = dimension
+ }
+
+ subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", siteName, accountName)
+ body := s.buildQuotaAlertEmailBody(accountName, dimLabel, used, limit, threshold, siteName)
+ s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension)
+}
+
+// buildBalanceLowEmailBody builds HTML email for balance low notification.
+func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance, threshold float64, siteName string) string {
+ return fmt.Sprintf(`
+
+
+
+
+
+
+
+
+
+
%s,您的余额不足
+
Dear %s, your balance is running low
+
$%.2f
+
+
您的账户余额已低于提醒阈值 $%.2f 。
+
Your account balance has fallen below the alert threshold of $%.2f .
+
请及时充值以免服务中断。
+
Please top up to avoid service interruption.
+
+
+
+
+
+`, siteName, userName, userName, balance, threshold, threshold)
+}
+
+// buildQuotaAlertEmailBody builds HTML email for account quota alert.
+func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountName, dimLabel string, used, limit, threshold float64, siteName string) string {
+ limitStr := fmt.Sprintf("$%.2f", limit)
+ if limit <= 0 {
+ limitStr = "无限制 / Unlimited"
+ }
+ return fmt.Sprintf(`
+
+
+
+
+
+
+
+
+
+
账号限额告警 / Account Quota Alert
+
账号 / Account %s
+
维度 / Dimension %s
+
已使用 / Used $%.2f
+
限额 / Limit %s
+
告警阈值 / Threshold $%.2f
+
+
账号配额用量已达到告警阈值,请及时关注。
+
Account quota usage has reached the alert threshold.
+
+
+
+
+
+`, siteName, accountName, dimLabel, used, limitStr, threshold)
+}
+
+// parseJSONStringArray parses a JSON string array; it returns nil on malformed JSON, empty input, or the literal "[]".
+func parseJSONStringArray(raw string) []string {
+ raw = strings.TrimSpace(raw)
+ if raw == "" || raw == "[]" {
+ return nil
+ }
+ var result []string
+ if err := json.Unmarshal([]byte(raw), &result); err != nil {
+ return nil
+ }
+ return result
+}
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index f43d388b..2704e0d0 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -250,9 +250,12 @@ const (
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
SettingKeyEnableCCHSigning = "enable_cch_signing"
- // Web Search Emulation
- // SettingKeyWebSearchEmulationConfig 全局 web search 模拟配置(JSON)
- SettingKeyWebSearchEmulationConfig = "web_search_emulation_config"
+ // Balance Low Notification
+ SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
+ SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD)
+
+ // Account Quota Notification
+ SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组)
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go
index 00691233..61090776 100644
--- a/backend/internal/service/email_service.go
+++ b/backend/internal/service/email_service.go
@@ -34,6 +34,11 @@ type EmailCache interface {
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteVerificationCode(ctx context.Context, email string) error
+ // Notify email verification code methods
+ GetNotifyVerifyCode(ctx context.Context, email string) (*VerificationCodeData, error)
+ SetNotifyVerifyCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
+ DeleteNotifyVerifyCode(ctx context.Context, email string) error
+
// Password reset token methods
GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error)
SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error
diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go
index 97703a9d..140bdc67 100644
--- a/backend/internal/service/gateway_record_usage_test.go
+++ b/backend/internal/service/gateway_record_usage_test.go
@@ -43,6 +43,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil,
nil,
nil,
+ nil,
)
}
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 1d6d0a08..72ab39ce 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -569,6 +569,7 @@ type GatewayService struct {
resolver *ModelPricingResolver
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService *TLSFingerprintProfileService
+ balanceNotifyService *BalanceNotifyService
}
// NewGatewayService creates a new GatewayService
@@ -598,6 +599,7 @@ func NewGatewayService(
tlsFPProfileService *TLSFingerprintProfileService,
channelService *ChannelService,
resolver *ModelPricingResolver,
+ balanceNotifyService *BalanceNotifyService,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg)
@@ -632,6 +634,7 @@ func NewGatewayService(
tlsFPProfileService: tlsFPProfileService,
channelService: channelService,
resolver: resolver,
+ balanceNotifyService: balanceNotifyService,
}
svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo,
@@ -7334,6 +7337,20 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
}
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
+
+ // Balance low notification. NOTE(review): p.User.Balance is passed as oldBalance; confirm it still holds the pre-deduction balance at this point (CheckBalanceAfterDeduction computes newBalance = oldBalance - cost).
+ if !p.IsSubscriptionBill && p.Cost.ActualCost > 0 && p.User != nil && deps.balanceNotifyService != nil {
+ deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, p.User.Balance, p.Cost.ActualCost)
+ }
+
+ // Account quota notification
+ if p.Cost.TotalCost > 0 && p.Account != nil && p.Account.IsAPIKeyOrBedrock() && deps.balanceNotifyService != nil {
+ accountCost := p.Cost.TotalCost
+ if p.AccountRateMultiplier > 0 {
+ accountCost *= p.AccountRateMultiplier
+ }
+ deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost)
+ }
}
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
@@ -7356,20 +7373,22 @@ func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Cont
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
type billingDeps struct {
- accountRepo AccountRepository
- userRepo UserRepository
- userSubRepo UserSubscriptionRepository
- billingCacheService *BillingCacheService
- deferredService *DeferredService
+ accountRepo AccountRepository
+ userRepo UserRepository
+ userSubRepo UserSubscriptionRepository
+ billingCacheService *BillingCacheService
+ deferredService *DeferredService
+ balanceNotifyService *BalanceNotifyService
}
func (s *GatewayService) billingDeps() *billingDeps {
return &billingDeps{
- accountRepo: s.accountRepo,
- userRepo: s.userRepo,
- userSubRepo: s.userSubRepo,
- billingCacheService: s.billingCacheService,
- deferredService: s.deferredService,
+ accountRepo: s.accountRepo,
+ userRepo: s.userRepo,
+ userSubRepo: s.userSubRepo,
+ billingCacheService: s.billingCacheService,
+ deferredService: s.deferredService,
+ balanceNotifyService: s.balanceNotifyService,
}
}
diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go
index 38b97b11..e6fa94aa 100644
--- a/backend/internal/service/openai_gateway_record_usage_test.go
+++ b/backend/internal/service/openai_gateway_record_usage_test.go
@@ -147,6 +147,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
nil,
nil,
+ nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 3daa8756..70abd4ce 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -327,6 +327,7 @@ type OpenAIGatewayService struct {
openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver
channelService *ChannelService
+ balanceNotifyService *BalanceNotifyService
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
@@ -364,6 +365,7 @@ func NewOpenAIGatewayService(
openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver,
channelService *ChannelService,
+ balanceNotifyService *BalanceNotifyService,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
@@ -393,6 +395,7 @@ func NewOpenAIGatewayService(
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver,
channelService: channelService,
+ balanceNotifyService: balanceNotifyService,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
@@ -477,11 +480,12 @@ func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
return &billingDeps{
- accountRepo: s.accountRepo,
- userRepo: s.userRepo,
- userSubRepo: s.userSubRepo,
- billingCacheService: s.billingCacheService,
- deferredService: s.deferredService,
+ accountRepo: s.accountRepo,
+ userRepo: s.userRepo,
+ userSubRepo: s.userSubRepo,
+ billingCacheService: s.billingCacheService,
+ deferredService: s.deferredService,
+ balanceNotifyService: s.balanceNotifyService,
}
}
diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go
index 3834dcb7..66e5db93 100644
--- a/backend/internal/service/openai_ws_protocol_forward_test.go
+++ b/backend/internal/service/openai_ws_protocol_forward_test.go
@@ -617,6 +617,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
+ nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 3cfe5e56..bc4f53ce 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -18,7 +18,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/imroc/req/v3"
- "github.com/redis/go-redis/v9"
"golang.org/x/sync/singleflight"
)
@@ -107,7 +106,6 @@ type SettingService struct {
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version
- webSearchRedis *redis.Client // optional: Redis client for web search quota tracking
}
// NewSettingService 创建系统设置服务实例
@@ -170,9 +168,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled,
+ SettingPaymentEnabled,
SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName,
- SettingPaymentEnabled,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -237,9 +235,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
+ PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
OIDCOAuthProviderName: oidcProviderName,
- PaymentEnabled: settings[SettingPaymentEnabled] == "true",
}, nil
}
@@ -289,9 +287,9 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
+ PaymentEnabled bool `json:"payment_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
- PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version,omitempty"`
}{
RegistrationEnabled: settings.RegistrationEnabled,
@@ -319,9 +317,9 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
+ PaymentEnabled: settings.PaymentEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
- PaymentEnabled: settings.PaymentEnabled,
Version: s.version,
}, nil
}
@@ -597,6 +595,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
+ // Balance low notification
+ updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
+ updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64)
+ accountQuotaNotifyEmailsJSON, err := json.Marshal(settings.AccountQuotaNotifyEmails)
+ if err != nil {
+ return fmt.Errorf("marshal account quota notify emails: %w", err)
+ }
+ updates[SettingKeyAccountQuotaNotifyEmails] = string(accountQuotaNotifyEmailsJSON)
+
err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil {
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
@@ -1219,13 +1226,22 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
- // Web search emulation: quick enabled check from the JSON config
- if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
- var wsCfg WebSearchEmulationConfig
- if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil {
- result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
+ // Balance low notification
+ result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
+ if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
+ result.BalanceLowNotifyThreshold = v
+ }
+
+ // Account quota notification emails
+ if raw := strings.TrimSpace(settings[SettingKeyAccountQuotaNotifyEmails]); raw != "" {
+ var emails []string
+ if err := json.Unmarshal([]byte(raw), &emails); err == nil {
+ result.AccountQuotaNotifyEmails = emails
}
}
+ if result.AccountQuotaNotifyEmails == nil {
+ result.AccountQuotaNotifyEmails = []string{}
+ }
return result
}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index f5535bca..debc2b19 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -107,8 +107,12 @@ type SystemSettings struct {
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
- // Web Search Emulation (read-only quick check; full config via dedicated API)
- WebSearchEmulationEnabled bool
+ // Balance low notification
+ BalanceLowNotifyEnabled bool
+ BalanceLowNotifyThreshold float64
+
+ // Account quota notification
+ AccountQuotaNotifyEmails []string
}
type DefaultSubscriptionSetting struct {
@@ -144,9 +148,9 @@ type PublicSettings struct {
LinuxDoOAuthEnabled bool
BackendModeEnabled bool
+ PaymentEnabled bool
OIDCOAuthEnabled bool
OIDCOAuthProviderName string
- PaymentEnabled bool
Version string
}
diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go
index e56d83bf..b4818223 100644
--- a/backend/internal/service/user.go
+++ b/backend/internal/service/user.go
@@ -30,6 +30,11 @@ type User struct {
TotpEnabled bool // 是否启用 TOTP
TotpEnabledAt *time.Time // TOTP 启用时间
+ // 余额不足通知
+ BalanceNotifyEnabled bool
+ BalanceNotifyThreshold *float64
+ BalanceNotifyExtraEmails []string
+
APIKeys []APIKey
Subscriptions []UserSubscription
}
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 4045c0aa..e6b9a210 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -2,8 +2,10 @@ package service
import (
"context"
+ "crypto/subtle"
"fmt"
"log"
+ "strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
@@ -16,6 +18,8 @@ var (
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
)
+const maxNotifyExtraEmails = 5
+
// UserListFilters contains all filter options for listing users
type UserListFilters struct {
Status string // User status filter
@@ -58,9 +62,11 @@ type UserRepository interface {
// UpdateProfileRequest 更新用户资料请求
type UpdateProfileRequest struct {
- Email *string `json:"email"`
- Username *string `json:"username"`
- Concurrency *int `json:"concurrency"`
+ Email *string `json:"email"`
+ Username *string `json:"username"`
+ Concurrency *int `json:"concurrency"`
+ BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
// ChangePasswordRequest 修改密码请求
@@ -72,14 +78,16 @@ type ChangePasswordRequest struct {
// UserService 用户服务
type UserService struct {
userRepo UserRepository
+ settingRepo SettingRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
billingCache BillingCache
}
// NewUserService 创建用户服务实例
-func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
+func NewUserService(userRepo UserRepository, settingRepo SettingRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
return &UserService{
userRepo: userRepo,
+ settingRepo: settingRepo,
authCacheInvalidator: authCacheInvalidator,
billingCache: billingCache,
}
@@ -132,6 +140,17 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Concurrency = *req.Concurrency
}
+ if req.BalanceNotifyEnabled != nil {
+ user.BalanceNotifyEnabled = *req.BalanceNotifyEnabled
+ }
+ if req.BalanceNotifyThreshold != nil {
+ if *req.BalanceNotifyThreshold <= 0 {
+ user.BalanceNotifyThreshold = nil // clear to system default
+ } else {
+ user.BalanceNotifyThreshold = req.BalanceNotifyThreshold
+ }
+ }
+
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
@@ -248,3 +267,148 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error {
}
return nil
}
+
+// SendNotifyEmailCode sends a verification code to the extra notification email.
+func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error {
+ // Check cooldown
+ existing, err := cache.GetNotifyVerifyCode(ctx, email)
+ if err == nil && existing != nil {
+ if time.Since(existing.CreatedAt) < verifyCodeCooldown {
+ return ErrVerifyCodeTooFrequent
+ }
+ }
+
+ // Generate code
+ code, err := emailService.GenerateVerifyCode()
+ if err != nil {
+ return fmt.Errorf("generate code: %w", err)
+ }
+
+ // Save to cache
+ data := &VerificationCodeData{
+ Code: code,
+ Attempts: 0,
+ CreatedAt: time.Now(),
+ }
+ if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
+ return fmt.Errorf("save verify code: %w", err)
+ }
+
+ // Get site name
+ siteName := "Sub2API"
+ if s.settingRepo != nil {
+ if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" {
+ siteName = name
+ }
+ }
+
+ // Build and send email
+ subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName)
+ body := buildNotifyVerifyEmailBody(code, siteName)
+ return emailService.SendEmail(ctx, email, subject, body)
+}
+
+// VerifyAndAddNotifyEmail verifies the code and adds the email to user's extra emails.
+func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, email, code string, cache EmailCache) error {
+ // Verify code
+ data, err := cache.GetNotifyVerifyCode(ctx, email)
+ if err != nil || data == nil {
+ return ErrInvalidVerifyCode
+ }
+ if data.Attempts >= maxVerifyCodeAttempts {
+ return ErrVerifyCodeMaxAttempts
+ }
+ if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
+ data.Attempts++
+ _ = cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL)
+ if data.Attempts >= maxVerifyCodeAttempts {
+ return ErrVerifyCodeMaxAttempts
+ }
+ return ErrInvalidVerifyCode
+ }
+
+ // Delete code after verification
+ _ = cache.DeleteNotifyVerifyCode(ctx, email)
+
+ // Add to user's extra emails
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return err
+ }
+
+ // Check if already exists
+ for _, e := range user.BalanceNotifyExtraEmails {
+ if strings.EqualFold(e, email) {
+ return nil // Already added
+ }
+ }
+
+ // Check limit
+ if len(user.BalanceNotifyExtraEmails) >= maxNotifyExtraEmails {
+ return infraerrors.BadRequest("TOO_MANY_NOTIFY_EMAILS", fmt.Sprintf("maximum %d extra notification emails allowed", maxNotifyExtraEmails))
+ }
+
+ user.BalanceNotifyExtraEmails = append(user.BalanceNotifyExtraEmails, email)
+ return s.userRepo.Update(ctx, user)
+}
+
+// RemoveNotifyEmail removes an email from user's extra notification emails.
+func (s *UserService) RemoveNotifyEmail(ctx context.Context, userID int64, email string) error {
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return err
+ }
+
+ filtered := make([]string, 0, len(user.BalanceNotifyExtraEmails))
+ for _, e := range user.BalanceNotifyExtraEmails {
+ if !strings.EqualFold(e, email) {
+ filtered = append(filtered, e)
+ }
+ }
+ user.BalanceNotifyExtraEmails = filtered
+ return s.userRepo.Update(ctx, user)
+}
+
+// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification.
+func buildNotifyVerifyEmailBody(code, siteName string) string {
+ return fmt.Sprintf(`
+
+
+
+
+
+
+
+
+
+
+
通知邮箱验证码 / Notification Email Verification
+
%s
+
+
您正在添加额外的通知邮箱,请输入此验证码完成验证。
+
You are adding an extra notification email. Please enter this code to verify.
+
此验证码将在 15 分钟 后失效。
+
This code will expire in 15 minutes .
+
如果您没有请求此验证码,请忽略此邮件。
+
If you did not request this code, please ignore this email.
+
+
+
+
+
+
+`, siteName, code)
+}
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index 7f6c748f..29267c19 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -114,7 +114,7 @@ func (m *mockBillingCache) InvalidateAPIKeyRateLimit(context.Context, int64) err
func TestUpdateBalance_Success(t *testing.T) {
repo := &mockUserRepo{}
cache := &mockBillingCache{}
- svc := NewUserService(repo, nil, cache)
+ svc := NewUserService(repo, nil, nil, cache)
err := svc.UpdateBalance(context.Background(), 42, 100.0)
require.NoError(t, err)
@@ -131,7 +131,7 @@ func TestUpdateBalance_Success(t *testing.T) {
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
repo := &mockUserRepo{}
- svc := NewUserService(repo, nil, nil) // billingCache = nil
+ svc := NewUserService(repo, nil, nil, nil) // billingCache = nil
err := svc.UpdateBalance(context.Background(), 1, 50.0)
require.NoError(t, err, "billingCache 为 nil 时不应 panic")
@@ -140,7 +140,7 @@ func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
repo := &mockUserRepo{}
cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")}
- svc := NewUserService(repo, nil, cache)
+ svc := NewUserService(repo, nil, nil, cache)
err := svc.UpdateBalance(context.Background(), 99, 200.0)
require.NoError(t, err, "缓存失效失败不应影响主流程返回值")
@@ -154,7 +154,7 @@ func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
cache := &mockBillingCache{}
- svc := NewUserService(repo, nil, cache)
+ svc := NewUserService(repo, nil, nil, cache)
err := svc.UpdateBalance(context.Background(), 1, 100.0)
require.Error(t, err, "repo 失败时应返回错误")
@@ -170,7 +170,7 @@ func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) {
repo := &mockUserRepo{}
auth := &mockAuthCacheInvalidator{}
cache := &mockBillingCache{}
- svc := NewUserService(repo, auth, cache)
+ svc := NewUserService(repo, nil, auth, cache)
err := svc.UpdateBalance(context.Background(), 77, 300.0)
require.NoError(t, err)
@@ -191,7 +191,7 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
auth := &mockAuthCacheInvalidator{}
cache := &mockBillingCache{}
- svc := NewUserService(repo, auth, cache)
+ svc := NewUserService(repo, nil, auth, cache)
require.NotNil(t, svc)
require.Equal(t, repo, svc.userRepo)
require.Equal(t, auth, svc.authCacheInvalidator)
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index a8ece8a3..2827f135 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -465,6 +465,7 @@ var ProviderSet = wire.NewSet(
ProvidePaymentConfigService,
NewPaymentService,
ProvidePaymentOrderExpiryService,
+ ProvideBalanceNotifyService,
)
// ProvidePaymentConfigService wraps NewPaymentConfigService to accept the named
@@ -473,6 +474,11 @@ func ProvidePaymentConfigService(entClient *dbent.Client, settingRepo SettingRep
return NewPaymentConfigService(entClient, settingRepo, []byte(key))
}
+// ProvideBalanceNotifyService creates a BalanceNotifyService.
+func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository) *BalanceNotifyService {
+ return NewBalanceNotifyService(emailService, settingRepo)
+}
+
// ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService.
func ProvidePaymentOrderExpiryService(paymentSvc *PaymentService) *PaymentOrderExpiryService {
svc := NewPaymentOrderExpiryService(paymentSvc, 60*time.Second)
diff --git a/backend/migrations/101_add_balance_notify_fields.sql b/backend/migrations/101_add_balance_notify_fields.sql
new file mode 100644
index 00000000..ef0a0930
--- /dev/null
+++ b/backend/migrations/101_add_balance_notify_fields.sql
@@ -0,0 +1,4 @@
+-- Balance notification user preferences
+ALTER TABLE users ADD COLUMN IF NOT EXISTS balance_notify_enabled BOOLEAN NOT NULL DEFAULT true;
+ALTER TABLE users ADD COLUMN IF NOT EXISTS balance_notify_threshold DECIMAL(20,8) DEFAULT NULL;
+ALTER TABLE users ADD COLUMN IF NOT EXISTS balance_notify_extra_emails TEXT NOT NULL DEFAULT '[]';
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 7fc6c852..c6323b00 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -134,6 +134,11 @@ export interface SystemSettings {
payment_cancel_rate_limit_window: number
payment_cancel_rate_limit_unit: string
payment_cancel_rate_limit_window_mode: string
+
+ // Balance & quota notification
+ balance_low_notify_enabled: boolean
+ balance_low_notify_threshold: number
+ account_quota_notify_emails: string[]
}
export interface UpdateSettingsRequest {
@@ -233,6 +238,10 @@ export interface UpdateSettingsRequest {
payment_cancel_rate_limit_window?: number
payment_cancel_rate_limit_unit?: string
payment_cancel_rate_limit_window_mode?: string
+ // Balance & quota notification
+ balance_low_notify_enabled?: boolean
+ balance_low_notify_threshold?: number
+ account_quota_notify_emails?: string[]
}
/**
diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts
index bfc0e30b..9ef0f59c 100644
--- a/frontend/src/api/user.ts
+++ b/frontend/src/api/user.ts
@@ -22,6 +22,9 @@ export async function getProfile(): Promise {
*/
export async function updateProfile(profile: {
username?: string
+ balance_notify_enabled?: boolean
+ balance_notify_threshold?: number | null
+ balance_notify_extra_emails?: string[]
}): Promise {
const { data } = await apiClient.put('/user', profile)
return data
@@ -45,10 +48,38 @@ export async function changePassword(
return data
}
+/**
+ * Send verification code for adding a notify email
+ * @param email - Email address to verify
+ */
+export async function sendNotifyEmailCode(email: string): Promise {
+ await apiClient.post('/user/notify-email/send-code', { email })
+}
+
+/**
+ * Verify and add a notify email
+ * @param email - Email address to add
+ * @param code - Verification code
+ */
+export async function verifyNotifyEmail(email: string, code: string): Promise {
+ await apiClient.post('/user/notify-email/verify', { email, code })
+}
+
+/**
+ * Remove a notify email
+ * @param email - Email address to remove
+ */
+export async function removeNotifyEmail(email: string): Promise {
+ await apiClient.delete('/user/notify-email', { data: { email } })
+}
+
export const userAPI = {
getProfile,
updateProfile,
- changePassword
+ changePassword,
+ sendNotifyEmailCode,
+ verifyNotifyEmail,
+ removeNotifyEmail
}
export default userAPI
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index a67366fc..54dae5c2 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -1186,6 +1186,12 @@
:weeklyResetDay="editWeeklyResetDay"
:weeklyResetHour="editWeeklyResetHour"
:resetTimezone="editResetTimezone"
+ :quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled"
+ :quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold"
+ :quotaNotifyWeeklyEnabled="editQuotaNotifyWeeklyEnabled"
+ :quotaNotifyWeeklyThreshold="editQuotaNotifyWeeklyThreshold"
+ :quotaNotifyTotalEnabled="editQuotaNotifyTotalEnabled"
+ :quotaNotifyTotalThreshold="editQuotaNotifyTotalThreshold"
@update:totalLimit="editQuotaLimit = $event"
@update:dailyLimit="editQuotaDailyLimit = $event"
@update:weeklyLimit="editQuotaWeeklyLimit = $event"
@@ -1195,6 +1201,12 @@
@update:weeklyResetDay="editWeeklyResetDay = $event"
@update:weeklyResetHour="editWeeklyResetHour = $event"
@update:resetTimezone="editResetTimezone = $event"
+ @update:quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled = $event"
+ @update:quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold = $event"
+ @update:quotaNotifyWeeklyEnabled="editQuotaNotifyWeeklyEnabled = $event"
+ @update:quotaNotifyWeeklyThreshold="editQuotaNotifyWeeklyThreshold = $event"
+ @update:quotaNotifyTotalEnabled="editQuotaNotifyTotalEnabled = $event"
+ @update:quotaNotifyTotalThreshold="editQuotaNotifyTotalThreshold = $event"
/>
@@ -1218,6 +1230,12 @@
:weeklyResetDay="editWeeklyResetDay"
:weeklyResetHour="editWeeklyResetHour"
:resetTimezone="editResetTimezone"
+ :quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled"
+ :quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold"
+ :quotaNotifyWeeklyEnabled="editQuotaNotifyWeeklyEnabled"
+ :quotaNotifyWeeklyThreshold="editQuotaNotifyWeeklyThreshold"
+ :quotaNotifyTotalEnabled="editQuotaNotifyTotalEnabled"
+ :quotaNotifyTotalThreshold="editQuotaNotifyTotalThreshold"
@update:totalLimit="editQuotaLimit = $event"
@update:dailyLimit="editQuotaDailyLimit = $event"
@update:weeklyLimit="editQuotaWeeklyLimit = $event"
@@ -1227,6 +1245,12 @@
@update:weeklyResetDay="editWeeklyResetDay = $event"
@update:weeklyResetHour="editWeeklyResetHour = $event"
@update:resetTimezone="editResetTimezone = $event"
+ @update:quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled = $event"
+ @update:quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold = $event"
+ @update:quotaNotifyWeeklyEnabled="editQuotaNotifyWeeklyEnabled = $event"
+ @update:quotaNotifyWeeklyThreshold="editQuotaNotifyWeeklyThreshold = $event"
+ @update:quotaNotifyTotalEnabled="editQuotaNotifyTotalEnabled = $event"
+ @update:quotaNotifyTotalThreshold="editQuotaNotifyTotalThreshold = $event"
/>
@@ -1960,6 +1984,12 @@ const editWeeklyResetMode = ref<'rolling' | 'fixed' | null>(null)
const editWeeklyResetDay = ref(null)
const editWeeklyResetHour = ref(null)
const editResetTimezone = ref(null)
+const editQuotaNotifyDailyEnabled = ref(null)
+const editQuotaNotifyDailyThreshold = ref(null)
+const editQuotaNotifyWeeklyEnabled = ref(null)
+const editQuotaNotifyWeeklyThreshold = ref(null)
+const editQuotaNotifyTotalEnabled = ref(null)
+const editQuotaNotifyTotalThreshold = ref(null)
const openAIWSModeOptions = computed(() => [
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
// TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复
@@ -2159,6 +2189,13 @@ const syncFormFromAccount = (newAccount: Account | null) => {
editWeeklyResetDay.value = (extra?.quota_weekly_reset_day as number) ?? null
editWeeklyResetHour.value = (extra?.quota_weekly_reset_hour as number) ?? null
editResetTimezone.value = (extra?.quota_reset_timezone as string) || null
+ // Load quota notify config
+ editQuotaNotifyDailyEnabled.value = (extra?.quota_notify_daily_enabled as boolean) ?? null
+ editQuotaNotifyDailyThreshold.value = (extra?.quota_notify_daily_threshold as number) ?? null
+ editQuotaNotifyWeeklyEnabled.value = (extra?.quota_notify_weekly_enabled as boolean) ?? null
+ editQuotaNotifyWeeklyThreshold.value = (extra?.quota_notify_weekly_threshold as number) ?? null
+ editQuotaNotifyTotalEnabled.value = (extra?.quota_notify_total_enabled as boolean) ?? null
+ editQuotaNotifyTotalThreshold.value = (extra?.quota_notify_total_threshold as number) ?? null
} else {
editQuotaLimit.value = null
editQuotaDailyLimit.value = null
@@ -2169,6 +2206,12 @@ const syncFormFromAccount = (newAccount: Account | null) => {
editWeeklyResetDay.value = null
editWeeklyResetHour.value = null
editResetTimezone.value = null
+ editQuotaNotifyDailyEnabled.value = null
+ editQuotaNotifyDailyThreshold.value = null
+ editQuotaNotifyWeeklyEnabled.value = null
+ editQuotaNotifyWeeklyThreshold.value = null
+ editQuotaNotifyTotalEnabled.value = null
+ editQuotaNotifyTotalThreshold.value = null
}
// Load antigravity model mapping (Antigravity 只支持映射模式)
@@ -2283,6 +2326,13 @@ const syncFormFromAccount = (newAccount: Account | null) => {
editQuotaLimit.value = typeof bedrockExtra.quota_limit === 'number' ? bedrockExtra.quota_limit : null
editQuotaDailyLimit.value = typeof bedrockExtra.quota_daily_limit === 'number' ? bedrockExtra.quota_daily_limit : null
editQuotaWeeklyLimit.value = typeof bedrockExtra.quota_weekly_limit === 'number' ? bedrockExtra.quota_weekly_limit : null
+ // Load quota notify for bedrock
+ editQuotaNotifyDailyEnabled.value = (bedrockExtra.quota_notify_daily_enabled as boolean) ?? null
+ editQuotaNotifyDailyThreshold.value = (bedrockExtra.quota_notify_daily_threshold as number) ?? null
+ editQuotaNotifyWeeklyEnabled.value = (bedrockExtra.quota_notify_weekly_enabled as boolean) ?? null
+ editQuotaNotifyWeeklyThreshold.value = (bedrockExtra.quota_notify_weekly_threshold as number) ?? null
+ editQuotaNotifyTotalEnabled.value = (bedrockExtra.quota_notify_total_enabled as boolean) ?? null
+ editQuotaNotifyTotalThreshold.value = (bedrockExtra.quota_notify_total_threshold as number) ?? null
// Load model mappings for bedrock
const existingMappings = bedrockCreds.model_mapping as Record | undefined
@@ -3198,6 +3248,40 @@ const handleSubmit = async () => {
} else {
delete newExtra.quota_reset_timezone
}
+ // Quota notify config
+ if (editQuotaNotifyDailyEnabled.value) {
+ newExtra.quota_notify_daily_enabled = true
+ if (editQuotaNotifyDailyThreshold.value != null) {
+ newExtra.quota_notify_daily_threshold = editQuotaNotifyDailyThreshold.value
+ } else {
+ delete newExtra.quota_notify_daily_threshold
+ }
+ } else {
+ delete newExtra.quota_notify_daily_enabled
+ delete newExtra.quota_notify_daily_threshold
+ }
+ if (editQuotaNotifyWeeklyEnabled.value) {
+ newExtra.quota_notify_weekly_enabled = true
+ if (editQuotaNotifyWeeklyThreshold.value != null) {
+ newExtra.quota_notify_weekly_threshold = editQuotaNotifyWeeklyThreshold.value
+ } else {
+ delete newExtra.quota_notify_weekly_threshold
+ }
+ } else {
+ delete newExtra.quota_notify_weekly_enabled
+ delete newExtra.quota_notify_weekly_threshold
+ }
+ if (editQuotaNotifyTotalEnabled.value) {
+ newExtra.quota_notify_total_enabled = true
+ if (editQuotaNotifyTotalThreshold.value != null) {
+ newExtra.quota_notify_total_threshold = editQuotaNotifyTotalThreshold.value
+ } else {
+ delete newExtra.quota_notify_total_threshold
+ }
+ } else {
+ delete newExtra.quota_notify_total_enabled
+ delete newExtra.quota_notify_total_threshold
+ }
updatePayload.extra = newExtra
}
diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue
index fdc19ad9..9840a5e1 100644
--- a/frontend/src/components/account/QuotaLimitCard.vue
+++ b/frontend/src/components/account/QuotaLimitCard.vue
@@ -4,7 +4,7 @@ import { useI18n } from 'vue-i18n'
const { t } = useI18n()
-const props = defineProps<{
+const props = withDefaults(defineProps<{
totalLimit: number | null
dailyLimit: number | null
weeklyLimit: number | null
@@ -14,7 +14,20 @@ const props = defineProps<{
weeklyResetDay: number | null
weeklyResetHour: number | null
resetTimezone: string | null
-}>()
+ quotaNotifyDailyEnabled?: boolean | null
+ quotaNotifyDailyThreshold?: number | null
+ quotaNotifyWeeklyEnabled?: boolean | null
+ quotaNotifyWeeklyThreshold?: number | null
+ quotaNotifyTotalEnabled?: boolean | null
+ quotaNotifyTotalThreshold?: number | null
+}>(), {
+ quotaNotifyDailyEnabled: null,
+ quotaNotifyDailyThreshold: null,
+ quotaNotifyWeeklyEnabled: null,
+ quotaNotifyWeeklyThreshold: null,
+ quotaNotifyTotalEnabled: null,
+ quotaNotifyTotalThreshold: null,
+})
const emit = defineEmits<{
'update:totalLimit': [value: number | null]
@@ -26,6 +39,12 @@ const emit = defineEmits<{
'update:weeklyResetDay': [value: number | null]
'update:weeklyResetHour': [value: number | null]
'update:resetTimezone': [value: string | null]
+ 'update:quotaNotifyDailyEnabled': [value: boolean | null]
+ 'update:quotaNotifyDailyThreshold': [value: number | null]
+ 'update:quotaNotifyWeeklyEnabled': [value: boolean | null]
+ 'update:quotaNotifyWeeklyThreshold': [value: number | null]
+ 'update:quotaNotifyTotalEnabled': [value: boolean | null]
+ 'update:quotaNotifyTotalThreshold': [value: number | null]
}>()
const enabled = computed(() =>
@@ -203,6 +222,36 @@ const onWeeklyModeChange = (e: Event) => {
{{ t('admin.accounts.quotaDailyLimitHint') }}
+
+
+
{{ t('admin.accounts.quotaNotify.alert') }}
+
+
+
+
+ $
+
+
+
@@ -259,6 +308,36 @@ const onWeeklyModeChange = (e: Event) => {
{{ t('admin.accounts.quotaWeeklyLimitHint') }}
+
+
+
{{ t('admin.accounts.quotaNotify.alert') }}
+
+
+
+
+ $
+
+
+
@@ -289,6 +368,36 @@ const onWeeklyModeChange = (e: Event) => {
/>
{{ t('admin.accounts.quotaTotalLimitHint') }}
+
+
+
{{ t('admin.accounts.quotaNotify.alert') }}
+
+
+
+
+ $
+
+
+
diff --git a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
new file mode 100644
index 00000000..130d82b5
--- /dev/null
+++ b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
@@ -0,0 +1,204 @@
+
+
+
+
+ {{ t('profile.balanceNotify.title') }}
+
+
+
+
+
+
+ {{ t('profile.balanceNotify.enabled') }}
+
+
+
+
+
+
+
+
+
+
+ {{ t('profile.balanceNotify.threshold') }}
+ {{ t('profile.balanceNotify.thresholdHint') }}
+
+
+ $
+
+
+
+
+
+
+
{{ t('profile.balanceNotify.extraEmails') }}
+
+
+
+
+ {{ email }}
+
+ {{ t('profile.balanceNotify.removeEmail') }}
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index dd45ea17..7119fa36 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -902,6 +902,31 @@ export default {
sendCode: 'Send Code',
codeSent: 'Verification code sent to your email',
sendCodeFailed: 'Failed to send verification code'
+ },
+ balanceNotify: {
+ title: 'Balance Low Notification',
+ description: 'Send email alert when account balance falls below threshold',
+ enabled: 'Enable Balance Low Notification',
+ threshold: 'Custom Threshold',
+ thresholdHint: 'Leave empty to use system default',
+ thresholdPlaceholder: 'Enter amount',
+ systemDefault: 'System Default',
+ extraEmails: 'Extra Notification Emails',
+ noExtraEmails: 'No extra notification emails',
+ enterEmail: 'Enter email address',
+ addEmail: 'Add Email',
+ emailPlaceholder: 'Enter email address',
+ sendCode: 'Send Code',
+ codeSent: 'Verification code sent',
+ codeSentTo: 'Code sent to {email}',
+ enterCode: 'Enter verification code',
+ codePlaceholder: '6-digit code',
+ verify: 'Verify & Add',
+ emailAdded: 'Email added',
+ emailRemoved: 'Email removed',
+ verifySuccess: 'Email added successfully',
+ removeEmail: 'Remove',
+ removeSuccess: 'Email removed',
}
},
@@ -2228,6 +2253,12 @@ export default {
},
quotaLimitAmount: 'Total Limit',
quotaLimitAmountHint: 'Cumulative spending limit. Does not auto-reset.',
+ quotaNotify: {
+ alert: 'Alert Threshold',
+ enabled: 'Enable Alert',
+ threshold: 'Alert Amount',
+ thresholdPlaceholder: 'Enter alert amount',
+ },
testConnection: 'Test Connection',
reAuthorize: 'Re-Authorize',
refreshToken: 'Refresh Token',
@@ -4593,6 +4624,22 @@ export default {
supportedTypesHint: 'Comma-separated, e.g. alipay,wxpay',
refundEnabled: 'Allow Refund',
},
+ balanceNotify: {
+ title: 'Balance Low Notification',
+ description: 'Send email notification when user balance falls below threshold',
+ enabled: 'Enable Balance Low Notification',
+ threshold: 'Default Threshold',
+ thresholdHint: 'Used when user has not set a custom value',
+ thresholdPlaceholder: 'Enter amount',
+ },
+ quotaNotify: {
+ title: 'Account Quota Notification',
+ description: 'Notify admins when account quota usage reaches alert threshold',
+ emails: 'Notification Emails',
+ emailsHint: 'Leave empty to disable notifications',
+ addEmail: 'Add Email',
+ emailPlaceholder: 'Enter email address',
+ },
smtp: {
title: 'SMTP Settings',
description: 'Configure email sending for verification codes',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index bbfc7971..6efaf657 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -906,6 +906,31 @@ export default {
sendCode: '发送验证码',
codeSent: '验证码已发送到您的邮箱',
sendCodeFailed: '发送验证码失败'
+ },
+ balanceNotify: {
+ title: '余额不足提醒',
+ description: '当账户余额低于阈值时发送邮件提醒',
+ enabled: '启用余额不足提醒',
+ threshold: '自定义提醒阈值',
+ thresholdHint: '留空使用系统默认值',
+ thresholdPlaceholder: '输入金额',
+ systemDefault: '系统默认值',
+ extraEmails: '额外通知邮箱',
+ noExtraEmails: '暂无额外通知邮箱',
+ enterEmail: '输入邮箱地址',
+ addEmail: '添加邮箱',
+ emailPlaceholder: '输入邮箱地址',
+ sendCode: '发送验证码',
+ codeSent: '验证码已发送',
+ codeSentTo: '验证码已发送到 {email}',
+ enterCode: '输入验证码',
+ codePlaceholder: '6位验证码',
+ verify: '确认添加',
+ emailAdded: '邮箱已添加',
+ emailRemoved: '邮箱已移除',
+ verifySuccess: '邮箱添加成功',
+ removeEmail: '移除',
+ removeSuccess: '邮箱已移除',
}
},
@@ -2226,6 +2251,12 @@ export default {
},
quotaLimitAmount: '总限额',
quotaLimitAmountHint: '累计消费上限,不会自动重置。',
+ quotaNotify: {
+ alert: '告警阈值',
+ enabled: '启用告警',
+ threshold: '告警金额',
+ thresholdPlaceholder: '输入告警金额',
+ },
testConnection: '测试连接',
reAuthorize: '重新授权',
refreshToken: '刷新令牌',
@@ -4757,6 +4788,22 @@ export default {
supportedTypesHint: '逗号分隔,如 alipay,wxpay',
refundEnabled: '允许退款',
},
+ balanceNotify: {
+ title: '余额不足提醒',
+ description: '当用户余额低于阈值时发送邮件提醒',
+ enabled: '启用余额不足提醒',
+ threshold: '默认提醒阈值',
+ thresholdHint: '用户未自定义时使用此值',
+ thresholdPlaceholder: '输入金额',
+ },
+ quotaNotify: {
+ title: '账号限额通知',
+ description: '当账号配额用量达到告警阈值时通知管理员',
+ emails: '通知邮箱',
+ emailsHint: '留空则不发送通知',
+ addEmail: '添加邮箱',
+ emailPlaceholder: '输入邮箱地址',
+ },
smtp: {
title: 'SMTP 设置',
description: '配置用于发送验证码的邮件服务',
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 8b7e44c1..e74f6e61 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -33,6 +33,9 @@ export interface User {
concurrency: number // Allowed concurrent requests
status: 'active' | 'disabled' // Account status
allowed_groups: number[] | null // Allowed group IDs (null = all non-exclusive groups)
+ balance_notify_enabled: boolean
+ balance_notify_threshold: number | null
+ balance_notify_extra_emails: string[]
subscriptions?: UserSubscription[] // User's active subscriptions
created_at: string
updated_at: string
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 6abc725a..f57bfcf3 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -2562,6 +2562,60 @@
+
+
+
+
+ {{ t('admin.settings.balanceNotify.title') }}
+
+
+ {{ t('admin.settings.balanceNotify.description') }}
+
+
+
+
+ {{ t('admin.settings.balanceNotify.enabled') }}
+
+
+
+
{{ t('admin.settings.balanceNotify.threshold') }}
+
+ $
+
+
+
{{ t('admin.settings.balanceNotify.thresholdHint') }}
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.quotaNotify.title') }}
+
+
+ {{ t('admin.settings.quotaNotify.description') }}
+
+
+
+
+
{{ t('admin.settings.quotaNotify.emails') }}
+
+
+
+
+
+
+
+
+ + {{ t('admin.settings.quotaNotify.addEmail') }}
+
+
+
{{ t('admin.settings.quotaNotify.emailsHint') }}
+
+
+
@@ -2840,7 +2894,11 @@ const form = reactive({
// Gateway forwarding behavior
enable_fingerprint_unification: true,
enable_metadata_passthrough: false,
- enable_cch_signing: false
+ enable_cch_signing: false,
+ // Balance & quota notification
+ balance_low_notify_enabled: false,
+ balance_low_notify_threshold: 0,
+ account_quota_notify_emails: [] as string[]
})
// Web Search Emulation config (loaded/saved separately)
@@ -2972,6 +3030,14 @@ function handleRegistrationEmailSuffixWhitelistPaste(event: ClipboardEvent) {
}
}
+// Quota notify email helpers
+const addQuotaNotifyEmail = () => {
+ if (!form.account_quota_notify_emails) {
+ form.account_quota_notify_emails = []
+ }
+ form.account_quota_notify_emails.push('')
+}
+
// LinuxDo OAuth redirect URL suggestion
const linuxdoRedirectUrlSuggestion = computed(() => {
if (typeof window === 'undefined') return ''
@@ -3311,6 +3377,10 @@ async function saveSettings() {
payment_cancel_rate_limit_window: Number(form.payment_cancel_rate_limit_window) || 1,
payment_cancel_rate_limit_unit: form.payment_cancel_rate_limit_unit,
payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode,
+ // Balance & quota notification
+ balance_low_notify_enabled: form.balance_low_notify_enabled,
+ balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0,
+ account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e: string) => e.trim() !== ''),
}
const updated = await adminAPI.settings.updateSettings(payload)
diff --git a/frontend/src/views/user/ProfileView.vue b/frontend/src/views/user/ProfileView.vue
index 0967e2b9..5534e1d6 100644
--- a/frontend/src/views/user/ProfileView.vue
+++ b/frontend/src/views/user/ProfileView.vue
@@ -14,6 +14,12 @@
+
@@ -27,6 +33,7 @@ import { authAPI } from '@/api'; import AppLayout from '@/components/layout/AppL
import StatCard from '@/components/common/StatCard.vue'
import ProfileInfoCard from '@/components/user/profile/ProfileInfoCard.vue'
import ProfileEditForm from '@/components/user/profile/ProfileEditForm.vue'
+import ProfileBalanceNotifyCard from '@/components/user/profile/ProfileBalanceNotifyCard.vue'
import ProfilePasswordForm from '@/components/user/profile/ProfilePasswordForm.vue'
import ProfileTotpCard from '@/components/user/profile/ProfileTotpCard.vue'
import { Icon } from '@/components/icons'
From c3812ce1e3fd845fe23a4cbc63f09655fd15fcab Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 12:48:17 +0800
Subject: [PATCH 21/88] fix(notify): address review findings - accountCost
formula, dedup, refactor
- Fix accountCost calculation in finalizePostUsageBilling to match
postUsageBilling (always multiply by AccountRateMultiplier)
- Use strings.EqualFold for email dedup in collectBalanceNotifyRecipients
- Extract CheckAccountQuotaAfterIncrement into smaller functions:
buildQuotaDims + asyncSendQuotaAlert (< 30 lines each)
- Add "not splittable" comments for HTML template functions
- Extract QuotaNotifyToggle.vue sub-component to reduce
QuotaLimitCard.vue from 404 to 339 lines
---
.../service/balance_notify_service.go | 82 ++++++-------
backend/internal/service/gateway_service.go | 7 +-
.../src/components/account/QuotaLimitCard.vue | 109 ++++--------------
.../components/account/QuotaNotifyToggle.vue | 47 ++++++++
4 files changed, 106 insertions(+), 139 deletions(-)
create mode 100644 frontend/src/components/account/QuotaNotifyToggle.vue
diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go
index 7cd61a0a..8223a231 100644
--- a/backend/internal/service/balance_notify_service.go
+++ b/backend/internal/service/balance_notify_service.go
@@ -85,73 +85,59 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u
}
}
+// quotaDim describes one quota dimension for notification checking.
+type quotaDim struct {
+ name string
+ enabled bool
+ threshold float64
+ oldUsed float64
+ limit float64
+}
+
+// buildQuotaDims returns the three quota dimensions for notification checking.
+func buildQuotaDims(account *Account) []quotaDim {
+ return []quotaDim{
+ {quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaDailyUsed(), account.GetQuotaDailyLimit()},
+ {quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaWeeklyUsed(), account.GetQuotaWeeklyLimit()},
+ {quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaUsed(), account.GetQuotaLimit()},
+ }
+}
+
// CheckAccountQuotaAfterIncrement checks if any quota dimension crossed above its notify threshold.
// The account's Extra fields contain pre-increment usage values.
func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Context, account *Account, cost float64) {
if account == nil || s.emailService == nil || s.settingRepo == nil || cost <= 0 {
return
}
-
adminEmails := s.getAccountQuotaNotifyEmails(ctx)
if len(adminEmails) == 0 {
return
}
siteName := s.getSiteName(ctx)
-
- // Check each dimension
- type quotaDim struct {
- name string
- enabled bool
- threshold float64
- oldUsed float64
- limit float64
- }
-
- dims := []quotaDim{
- {
- name: quotaDimDaily,
- enabled: account.GetQuotaNotifyDailyEnabled(),
- threshold: account.GetQuotaNotifyDailyThreshold(),
- oldUsed: account.GetQuotaDailyUsed(),
- limit: account.GetQuotaDailyLimit(),
- },
- {
- name: quotaDimWeekly,
- enabled: account.GetQuotaNotifyWeeklyEnabled(),
- threshold: account.GetQuotaNotifyWeeklyThreshold(),
- oldUsed: account.GetQuotaWeeklyUsed(),
- limit: account.GetQuotaWeeklyLimit(),
- },
- {
- name: quotaDimTotal,
- enabled: account.GetQuotaNotifyTotalEnabled(),
- threshold: account.GetQuotaNotifyTotalThreshold(),
- oldUsed: account.GetQuotaUsed(),
- limit: account.GetQuotaLimit(),
- },
- }
-
- for _, dim := range dims {
+ for _, dim := range buildQuotaDims(account) {
if !dim.enabled || dim.threshold <= 0 {
continue
}
newUsed := dim.oldUsed + cost
- // Only notify on first crossing
if dim.oldUsed < dim.threshold && newUsed >= dim.threshold {
- dimCopy := dim // capture loop variable
- go func() {
- defer func() {
- if r := recover(); r != nil {
- slog.Error("panic in quota notification", "recover", r)
- }
- }()
- s.sendQuotaAlertEmails(adminEmails, account.Name, dimCopy.name, newUsed, dimCopy.limit, dimCopy.threshold, siteName)
- }()
+ s.asyncSendQuotaAlert(adminEmails, account.Name, dim, newUsed, siteName)
}
}
}
+// asyncSendQuotaAlert sends the quota alert email in a goroutine with panic recovery.
+func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, accountName string, dim quotaDim, newUsed float64, siteName string) {
+ go func() {
+ defer func() {
+ if r := recover(); r != nil {
+ slog.Error("panic in quota notification", "recover", r)
+ }
+ }()
+ s.sendQuotaAlertEmails(adminEmails, accountName, dim.name, newUsed, dim.limit, dim.threshold, siteName)
+ }()
+}
+
// getBalanceNotifyConfig reads global balance notification settings.
func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, threshold float64) {
keys := []string{SettingKeyBalanceLowNotifyEnabled, SettingKeyBalanceLowNotifyThreshold}
@@ -191,7 +177,7 @@ func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []stri
recipients := []string{user.Email}
for _, extra := range user.BalanceNotifyExtraEmails {
email := strings.TrimSpace(extra)
- if email != "" && email != user.Email {
+ if email != "" && !strings.EqualFold(email, user.Email) {
recipients = append(recipients, email)
}
}
@@ -234,6 +220,7 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun
}
// buildBalanceLowEmailBody builds HTML email for balance low notification.
+// Exceeds the 30-line guideline because of the inline HTML template (not splittable).
func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance, threshold float64, siteName string) string {
return fmt.Sprintf(`
@@ -271,6 +258,7 @@ func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance
}
// buildQuotaAlertEmailBody builds HTML email for account quota alert.
+// Exceeds the 30-line guideline because of the inline HTML template (not splittable).
func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountName, dimLabel string, used, limit, threshold float64, siteName string) string {
limitStr := fmt.Sprintf("$%.2f", limit)
if limit <= 0 {
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 72ab39ce..1203f0c6 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -7343,12 +7343,9 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, p.User.Balance, p.Cost.ActualCost)
}
- // Account quota notification
+ // Account quota notification (use same cost formula as postUsageBilling)
if p.Cost.TotalCost > 0 && p.Account != nil && p.Account.IsAPIKeyOrBedrock() && deps.balanceNotifyService != nil {
- accountCost := p.Cost.TotalCost
- if p.AccountRateMultiplier > 0 {
- accountCost *= p.AccountRateMultiplier
- }
+ accountCost := p.Cost.TotalCost * p.AccountRateMultiplier
deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost)
}
}
diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue
index 9840a5e1..7c3afd23 100644
--- a/frontend/src/components/account/QuotaLimitCard.vue
+++ b/frontend/src/components/account/QuotaLimitCard.vue
@@ -1,6 +1,7 @@
+
+
+
+
{{ t('admin.accounts.quotaNotify.alert') }}
+
+
+
+
+ $
+
+
+
+
From 30b926add431f41c57e4c972190920fe3884401e Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 12:50:27 +0800
Subject: [PATCH 22/88] fix(notify): per-recipient timeout and return user on
email removal
- Use per-recipient context timeout in sendEmails to prevent later
recipients from failing due to shared timeout exhaustion
- Return updated user object from RemoveNotifyEmail handler for
frontend state consistency (matching VerifyNotifyEmail pattern)
---
backend/internal/handler/user_handler.go | 9 ++++++++-
backend/internal/service/balance_notify_service.go | 4 ++--
2 files changed, 10 insertions(+), 3 deletions(-)
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 42463a7a..4fb72ce7 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -205,5 +205,12 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
return
}
- response.Success(c, gin.H{"message": "Email removed successfully"})
+ // Return updated user
+ updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, dto.UserFromService(updatedUser))
}
diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go
index 8223a231..8dd56b8f 100644
--- a/backend/internal/service/balance_notify_service.go
+++ b/backend/internal/service/balance_notify_service.go
@@ -186,13 +186,13 @@ func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []stri
// sendEmails sends an email to all recipients with shared timeout and error logging.
func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body string, logAttrs ...any) {
- ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
- defer cancel()
for _, to := range recipients {
+ ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
if err := s.emailService.SendEmail(ctx, to, subject, body); err != nil {
attrs := append([]any{"to", to, "error", err}, logAttrs...)
slog.Error("failed to send notification", attrs...)
}
+ cancel()
}
}
From d0674e0ff94028f32c5f0882b16b66013719d5b0 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 13:11:46 +0800
Subject: [PATCH 23/88] feat(websearch): settings UI overhaul and quota
improvements
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Remove Priority field, auto load-balance by quota remaining
- Replace QuotaRefreshInterval (daily/weekly/monthly) with SubscribedAt
(subscription date, monthly lazy refresh via Redis TTL)
- Add collapsible provider cards, API key show/copy, usage progress bar
- Add test endpoint (POST /web-search-emulation/test) bypassing quota
- Wire WebSearchManagerBuilder on startup (was never called before)
- Fix nextMonthlyReset day-of-month overflow (Jan 31 → Feb 28)
- Fix non-deterministic sort in selectByQuotaWeight
- Map ProxyID in builder for provider-level proxy tracking
- Fix frontend timezone drift in subscribed_at date picker
- Fix provider deletion index shift for expandedProviders state
---
.../internal/handler/admin/setting_handler.go | 86 +++--
backend/internal/pkg/websearch/manager.go | 173 ++++++---
.../internal/pkg/websearch/manager_test.go | 136 ++++---
backend/internal/server/http.go | 30 ++
backend/internal/server/routes/admin.go | 1 +
backend/internal/service/websearch_config.go | 69 ++--
.../internal/service/websearch_config_test.go | 45 +--
frontend/src/api/admin/settings.ts | 22 +-
frontend/src/i18n/locales/en.ts | 21 +-
frontend/src/i18n/locales/zh.ts | 21 +-
frontend/src/views/admin/SettingsView.vue | 351 ++++++++++++------
11 files changed, 627 insertions(+), 328 deletions(-)
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 459eade9..e5e024c6 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -175,9 +175,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
- BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
- BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
- AccountQuotaNotifyEmails: settings.AccountQuotaNotifyEmails,
+ WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
PaymentEnabled: paymentCfg.Enabled,
PaymentMinAmount: paymentCfg.MinAmount,
PaymentMaxAmount: paymentCfg.MaxAmount,
@@ -307,11 +305,6 @@ type UpdateSettingsRequest struct {
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"`
- // Balance low notification
- BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
- BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
- AccountQuotaNotifyEmails *[]string `json:"account_quota_notify_emails"`
-
// Payment configuration (integrated into settings, full replace)
PaymentEnabled *bool `json:"payment_enabled"`
PaymentMinAmount *float64 `json:"payment_min_amount"`
@@ -889,24 +882,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableCCHSigning
}(),
- BalanceLowNotifyEnabled: func() bool {
- if req.BalanceLowNotifyEnabled != nil {
- return *req.BalanceLowNotifyEnabled
- }
- return previousSettings.BalanceLowNotifyEnabled
- }(),
- BalanceLowNotifyThreshold: func() float64 {
- if req.BalanceLowNotifyThreshold != nil {
- return *req.BalanceLowNotifyThreshold
- }
- return previousSettings.BalanceLowNotifyThreshold
- }(),
- AccountQuotaNotifyEmails: func() []string {
- if req.AccountQuotaNotifyEmails != nil {
- return *req.AccountQuotaNotifyEmails
- }
- return previousSettings.AccountQuotaNotifyEmails
- }(),
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
@@ -1053,9 +1028,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning,
- BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
- BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
- AccountQuotaNotifyEmails: updatedSettings.AccountQuotaNotifyEmails,
PaymentEnabled: updatedPaymentCfg.Enabled,
PaymentMinAmount: updatedPaymentCfg.MinAmount,
PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
@@ -1876,3 +1848,59 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
})
}
+
+// GetWebSearchEmulationConfig 获取 Web Search 模拟配置
+// GET /api/v1/admin/settings/web-search-emulation
+func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) {
+ cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, service.SanitizeWebSearchConfig(c.Request.Context(), cfg))
+}
+
+// UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置
+// PUT /api/v1/admin/settings/web-search-emulation
+func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) {
+ var cfg service.WebSearchEmulationConfig
+ if err := c.ShouldBindJSON(&cfg); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Re-read (with sanitized api keys) to return current state
+ updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, service.SanitizeWebSearchConfig(c.Request.Context(), updated))
+}
+
+// TestWebSearchEmulation 测试 Web Search 搜索
+// POST /api/v1/admin/settings/web-search-emulation/test
+func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) {
+ var req struct {
+ Query string `json:"query"`
+ }
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ if strings.TrimSpace(req.Query) == "" {
+ req.Query = "搜索今年世界大事件"
+ }
+
+ result, err := service.TestWebSearch(c.Request.Context(), req.Query)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, result)
+}
diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go
index 7db3d9a2..ae0683ad 100644
--- a/backend/internal/pkg/websearch/manager.go
+++ b/backend/internal/pkg/websearch/manager.go
@@ -19,26 +19,18 @@ import (
"github.com/redis/go-redis/v9"
)
-// Quota refresh interval constants.
-const (
- QuotaRefreshDaily = "daily"
- QuotaRefreshWeekly = "weekly"
- QuotaRefreshMonthly = "monthly"
-)
-
// ProviderConfig holds the configuration for a single search provider.
type ProviderConfig struct {
- Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily
- APIKey string `json:"api_key"` // secret
- Priority int `json:"priority"` // lower = higher priority
- QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
- QuotaRefreshInterval string `json:"quota_refresh_interval"` // QuotaRefreshDaily / Weekly / Monthly
- ProxyURL string `json:"-"` // resolved proxy URL (not persisted)
- ProxyID int64 `json:"-"` // resolved proxy ID for unavailability tracking
- ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds)
+ Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily
+ APIKey string `json:"api_key"` // secret
+ QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
+ SubscribedAt *int64 `json:"subscribed_at,omitempty"` // subscription start (unix seconds); quota resets monthly from this date
+ ProxyURL string `json:"-"` // resolved proxy URL (not persisted)
+ ProxyID int64 `json:"-"` // resolved proxy ID for unavailability tracking
+ ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds)
}
-// Manager selects providers by priority and tracks quota via Redis.
+// Manager selects providers by quota-weighted load balancing and tracks quota via Redis.
type Manager struct {
configs []ProviderConfig
redis *redis.Client
@@ -58,6 +50,7 @@ const (
proxyUnavailableKey = "websearch:proxy_unavailable:%d"
proxyUnavailableTTL = 5 * time.Minute
quotaTTLBuffer = 24 * time.Hour
+ defaultQuotaTTL = 31*24*time.Hour + quotaTTLBuffer // fallback when no subscription date
maxCachedClients = 100
)
@@ -80,14 +73,12 @@ return val
`)
// NewManager creates a Manager with the given provider configs and Redis client.
+// Provider order is preserved as-is; selectByQuotaWeight handles load balancing.
func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager {
- sorted := make([]ProviderConfig, len(configs))
- copy(sorted, configs)
- sort.Slice(sorted, func(i, j int) bool {
- return sorted[i].Priority < sorted[j].Priority
- })
+ copied := make([]ProviderConfig, len(configs))
+ copy(copied, configs)
return &Manager{
- configs: sorted,
+ configs: copied,
redis: redisClient,
clientCache: make(map[string]*http.Client),
}
@@ -162,21 +153,28 @@ type weighted struct {
// Providers with quota_limit=0 (no limit set) get weight 0 and are placed last.
// Among providers with quota, higher remaining quota = higher priority.
func (m *Manager) selectByQuotaWeight(ctx context.Context, candidates []ProviderConfig) []ProviderConfig {
+ items := m.computeWeights(ctx, candidates)
+ withQuota, withoutQuota := partitionByQuota(items)
+ sortByStableRandomWeight(withQuota)
+ return mergeWeightedResults(withQuota, withoutQuota, len(candidates))
+}
+
+func (m *Manager) computeWeights(ctx context.Context, candidates []ProviderConfig) []weighted {
items := make([]weighted, 0, len(candidates))
for _, cfg := range candidates {
w := int64(0)
if cfg.QuotaLimit > 0 {
- used, _ := m.GetUsage(ctx, cfg.Type, cfg.QuotaRefreshInterval)
- remaining := cfg.QuotaLimit - used
- if remaining > 0 {
+ used, _ := m.GetUsage(ctx, cfg.Type)
+ if remaining := cfg.QuotaLimit - used; remaining > 0 {
w = remaining
}
}
items = append(items, weighted{cfg: cfg, weight: w})
}
+ return items
+}
- // Separate providers with quota (weight > 0) from those without (weight == 0)
- var withQuota, withoutQuota []weighted
+func partitionByQuota(items []weighted) (withQuota, withoutQuota []weighted) {
for _, item := range items {
if item.weight > 0 {
withQuota = append(withQuota, item)
@@ -184,18 +182,26 @@ func (m *Manager) selectByQuotaWeight(ctx context.Context, candidates []Provider
withoutQuota = append(withoutQuota, item)
}
}
+ return
+}
- // Within quota group: weighted random sort (higher remaining = more likely first)
- if len(withQuota) > 1 {
- sort.Slice(withQuota, func(i, j int) bool {
- wi := float64(withQuota[i].weight) * (0.5 + rand.Float64())
- wj := float64(withQuota[j].weight) * (0.5 + rand.Float64())
- return wi > wj
- })
+// sortByStableRandomWeight assigns a fixed random factor to each item before sorting,
+// ensuring deterministic sort behavior (transitivity) within a single call.
+func sortByStableRandomWeight(items []weighted) {
+ if len(items) <= 1 {
+ return
}
+ factors := make([]float64, len(items))
+ for i, item := range items {
+ factors[i] = float64(item.weight) * (0.5 + rand.Float64())
+ }
+ sort.Slice(items, func(i, j int) bool {
+ return factors[i] > factors[j]
+ })
+}
- // Build final order: quota providers first, then no-quota providers (original priority order)
- result := make([]ProviderConfig, 0, len(candidates))
+func mergeWeightedResults(withQuota, withoutQuota []weighted, capacity int) []ProviderConfig {
+ result := make([]ProviderConfig, 0, capacity)
for _, item := range withQuota {
result = append(result, item.cfg)
}
@@ -294,8 +300,8 @@ func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool
slog.Warn("websearch: Redis unavailable, quota check skipped", "provider", cfg.Type)
return true, false
}
- key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
- ttlSec := int(quotaTTL(cfg.QuotaRefreshInterval).Seconds())
+ key := quotaRedisKey(cfg.Type)
+ ttlSec := int(quotaTTLFromSubscription(cfg.SubscribedAt).Seconds())
newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64()
if err != nil {
slog.Warn("websearch: quota Lua INCR failed, allowing request",
@@ -318,7 +324,7 @@ func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
if cfg.QuotaLimit <= 0 || m.redis == nil {
return
}
- key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
+ key := quotaRedisKey(cfg.Type)
if err := m.redis.Decr(ctx, key).Err(); err != nil {
slog.Warn("websearch: quota rollback DECR failed",
"provider", cfg.Type, "error", err)
@@ -327,6 +333,25 @@ func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
// --- Search execution ---
+// TestSearch executes a search using the first available provider without reserving quota.
+// Intended for admin test functionality only.
+func (m *Manager) TestSearch(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
+ if strings.TrimSpace(req.Query) == "" {
+ return nil, "", fmt.Errorf("websearch: empty search query")
+ }
+ for _, cfg := range m.configs {
+ if !m.isProviderAvailable(cfg) {
+ continue
+ }
+ resp, err := m.executeSearch(ctx, cfg, req)
+ if err != nil {
+ continue
+ }
+ return resp, cfg.Type, nil
+ }
+ return nil, "", fmt.Errorf("websearch: no available provider")
+}
+
func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) {
proxyURL := cfg.ProxyURL
if req.ProxyURL != "" {
@@ -384,11 +409,11 @@ func newHTTPClient(proxyURL string) (*http.Client, error) {
}
// GetUsage returns the current usage count for the given provider.
-func (m *Manager) GetUsage(ctx context.Context, providerType, refreshInterval string) (int64, error) {
+func (m *Manager) GetUsage(ctx context.Context, providerType string) (int64, error) {
if m.redis == nil {
return 0, nil
}
- key := quotaRedisKey(providerType, refreshInterval)
+ key := quotaRedisKey(providerType)
val, err := m.redis.Get(ctx, key).Int64()
if err == redis.Nil {
return 0, nil
@@ -400,7 +425,7 @@ func (m *Manager) GetUsage(ctx context.Context, providerType, refreshInterval st
func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 {
result := make(map[string]int64, len(m.configs))
for _, cfg := range m.configs {
- used, _ := m.GetUsage(ctx, cfg.Type, cfg.QuotaRefreshInterval)
+ used, _ := m.GetUsage(ctx, cfg.Type)
result[cfg.Type] = used
}
return result
@@ -423,30 +448,56 @@ func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provide
// --- Redis key helpers ---
-func quotaRedisKey(providerType, refreshInterval string) string {
- return quotaKeyPrefix + providerType + ":" + periodKey(refreshInterval)
+func quotaRedisKey(providerType string) string {
+ return quotaKeyPrefix + providerType
}
-func periodKey(refreshInterval string) string {
+// quotaTTLFromSubscription calculates the TTL for the quota counter based on
+// the provider's subscription start date. Quota resets monthly from that date.
+// When the Redis key expires naturally, the next INCR creates a fresh counter (lazy refresh).
+func quotaTTLFromSubscription(subscribedAt *int64) time.Duration {
+ if subscribedAt == nil || *subscribedAt == 0 {
+ return defaultQuotaTTL
+ }
+ next := nextMonthlyReset(time.Unix(*subscribedAt, 0).UTC())
+ ttl := time.Until(next) + quotaTTLBuffer
+ if ttl <= quotaTTLBuffer {
+ // Defensive: nextMonthlyReset always returns a future time, so this should not trigger; fall back to a full cycle just in case
+ ttl = defaultQuotaTTL
+ }
+ return ttl
+}
+
+// nextMonthlyReset returns the next monthly reset time based on the subscription start date.
+// E.g., subscribed on Jan 15 → resets on Feb 15, Mar 15, etc.
+// Handles day-of-month overflow: Jan 31 → Feb 28 (not Mar 3).
+func nextMonthlyReset(subscribedAt time.Time) time.Time {
now := time.Now().UTC()
- switch refreshInterval {
- case QuotaRefreshDaily:
- return now.Format("2006-01-02")
- case QuotaRefreshWeekly:
- year, week := now.ISOWeek()
- return fmt.Sprintf("%d-W%02d", year, week)
- default:
- return now.Format("2006-01")
+ if subscribedAt.IsZero() {
+ return now.AddDate(0, 1, 0)
}
+ months := (now.Year()-subscribedAt.Year())*12 + int(now.Month()-subscribedAt.Month())
+ if months < 0 {
+ months = 0
+ }
+ candidate := addMonthsClamped(subscribedAt, months)
+ if candidate.After(now) {
+ return candidate
+ }
+ return addMonthsClamped(subscribedAt, months+1)
}
-func quotaTTL(refreshInterval string) time.Duration {
- switch refreshInterval {
- case QuotaRefreshDaily:
- return 24*time.Hour + quotaTTLBuffer
- case QuotaRefreshWeekly:
- return 7*24*time.Hour + quotaTTLBuffer
- default:
- return 31*24*time.Hour + quotaTTLBuffer
+// addMonthsClamped adds N months to a date, clamping the day to the last day of the target month.
+// E.g., Jan 31 + 1 month = Feb 28 (not Mar 3).
+func addMonthsClamped(t time.Time, months int) time.Time {
+ y, m, d := t.Date()
+ targetMonth := time.Month(int(m) + months)
+ targetYear := y + int(targetMonth-1)/12
+ targetMonth = (targetMonth-1)%12 + 1
+ // Last day of the target month
+ lastDay := time.Date(targetYear, targetMonth+1, 0, 0, 0, 0, 0, time.UTC).Day()
+ if d > lastDay {
+ d = lastDay
}
+ return time.Date(targetYear, targetMonth, d, 0, 0, 0, 0, time.UTC)
}
diff --git a/backend/internal/pkg/websearch/manager_test.go b/backend/internal/pkg/websearch/manager_test.go
index d3cd29d6..a4beef68 100644
--- a/backend/internal/pkg/websearch/manager_test.go
+++ b/backend/internal/pkg/websearch/manager_test.go
@@ -12,14 +12,14 @@ import (
"github.com/stretchr/testify/require"
)
-func TestNewManager_SortsByPriority(t *testing.T) {
+func TestNewManager_PreservesOrder(t *testing.T) {
configs := []ProviderConfig{
- {Type: "brave", APIKey: "k3", Priority: 30},
- {Type: "tavily", APIKey: "k1", Priority: 10},
+ {Type: "brave", APIKey: "k3"},
+ {Type: "tavily", APIKey: "k1"},
}
m := NewManager(configs, nil)
- require.Equal(t, 10, m.configs[0].Priority)
- require.Equal(t, 30, m.configs[1].Priority)
+ require.Equal(t, "brave", m.configs[0].Type)
+ require.Equal(t, "tavily", m.configs[1].Type)
}
func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) {
@@ -46,8 +46,7 @@ func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) {
require.ErrorContains(t, err, "no available provider")
}
-func TestManager_SearchWithBestProvider_PriorityOrder(t *testing.T) {
- // Create two mock servers that return different results
+func TestManager_SearchWithBestProvider_UsesFirstAvailable(t *testing.T) {
srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
resp := braveResponse{}
resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}}
@@ -55,17 +54,15 @@ func TestManager_SearchWithBestProvider_PriorityOrder(t *testing.T) {
}))
defer srvBrave.Close()
- // Override brave endpoint for test
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srvBrave.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
m := NewManager([]ProviderConfig{
- {Type: "brave", APIKey: "k1", Priority: 1},
- {Type: "tavily", APIKey: "k2", Priority: 2},
+ {Type: "brave", APIKey: "k1"},
+ {Type: "tavily", APIKey: "k2"},
}, nil)
- // Inject the test server's client
m.clientCache[srvBrave.URL] = srvBrave.Client()
m.clientCache[""] = srvBrave.Client()
@@ -77,7 +74,6 @@ func TestManager_SearchWithBestProvider_PriorityOrder(t *testing.T) {
}
func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
- // With nil Redis, quota check is skipped (always allowed)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
resp := braveResponse{}
resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}}
@@ -91,8 +87,8 @@ func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
defer func() { *braveSearchURL = origURL }()
m := NewManager([]ProviderConfig{
- {Type: "brave", APIKey: "k", Priority: 1, QuotaLimit: 100},
- }, nil) // nil Redis
+ {Type: "brave", APIKey: "k", QuotaLimit: 100},
+ }, nil)
m.clientCache[""] = srv.Client()
resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
@@ -102,51 +98,98 @@ func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
func TestManager_GetUsage_NilRedis(t *testing.T) {
m := NewManager(nil, nil)
- used, err := m.GetUsage(context.Background(), "brave", "monthly")
+ used, err := m.GetUsage(context.Background(), "brave")
require.NoError(t, err)
require.Equal(t, int64(0), used)
}
func TestManager_GetAllUsage_NilRedis(t *testing.T) {
m := NewManager([]ProviderConfig{
- {Type: "brave", QuotaRefreshInterval: "monthly"},
+ {Type: "brave"},
}, nil)
usage := m.GetAllUsage(context.Background())
require.Equal(t, int64(0), usage["brave"])
}
-// --- Key/TTL helpers ---
+// --- Quota TTL from subscription ---
-func TestQuotaTTL_Daily(t *testing.T) {
- require.Equal(t, 24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshDaily))
+func TestQuotaTTLFromSubscription_NilSubscription(t *testing.T) {
+ ttl := quotaTTLFromSubscription(nil)
+ require.Equal(t, defaultQuotaTTL, ttl)
}
-func TestQuotaTTL_Weekly(t *testing.T) {
- require.Equal(t, 7*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshWeekly))
+func TestQuotaTTLFromSubscription_ZeroSubscription(t *testing.T) {
+ zero := int64(0)
+ ttl := quotaTTLFromSubscription(&zero)
+ require.Equal(t, defaultQuotaTTL, ttl)
}
-func TestQuotaTTL_Monthly(t *testing.T) {
- require.Equal(t, 31*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshMonthly))
+func TestQuotaTTLFromSubscription_ValidSubscription(t *testing.T) {
+ // Subscribed 10 days ago — next reset in ~18–21 days depending on month length
+ sub := time.Now().Add(-10 * 24 * time.Hour).Unix()
+ ttl := quotaTTLFromSubscription(&sub)
+ require.Greater(t, ttl, 15*24*time.Hour) // at least 15 days
+ require.Less(t, ttl, 25*24*time.Hour+quotaTTLBuffer)
}
-func TestPeriodKey_Daily(t *testing.T) {
- key := periodKey(QuotaRefreshDaily)
- require.Regexp(t, `^\d{4}-\d{2}-\d{2}$`, key)
+func TestNextMonthlyReset_SubscribedRecentPast(t *testing.T) {
+ // Subscribed on the 10th of this month (always valid day)
+ now := time.Now().UTC()
+ sub := time.Date(now.Year(), now.Month(), 10, 0, 0, 0, 0, time.UTC)
+ next := nextMonthlyReset(sub)
+ require.True(t, next.After(now) || next.Equal(now), "next reset should be in the future or now")
+ require.True(t, next.Before(now.AddDate(0, 1, 1)))
}
-func TestPeriodKey_Weekly(t *testing.T) {
- key := periodKey(QuotaRefreshWeekly)
- require.Regexp(t, `^\d{4}-W\d{2}$`, key)
+func TestNextMonthlyReset_SubscribedLongAgo(t *testing.T) {
+ // Subscribed 6 months ago on the 1st
+ sub := time.Now().UTC().AddDate(0, -6, 0)
+ sub = time.Date(sub.Year(), sub.Month(), 1, 0, 0, 0, 0, time.UTC)
+ next := nextMonthlyReset(sub)
+ require.True(t, next.After(time.Now().UTC()))
+ // Should be within the next 31 days
+ require.True(t, next.Before(time.Now().UTC().AddDate(0, 1, 1)))
}
-func TestPeriodKey_Monthly(t *testing.T) {
- key := periodKey(QuotaRefreshMonthly)
- require.Regexp(t, `^\d{4}-\d{2}$`, key)
+func TestNextMonthlyReset_FutureSubscription(t *testing.T) {
+ sub := time.Now().UTC().AddDate(0, 0, 5)
+ next := nextMonthlyReset(sub)
+ require.True(t, next.After(time.Now().UTC()))
}
+func TestAddMonthsClamped_Jan31ToFeb(t *testing.T) {
+ sub := time.Date(2026, 1, 31, 0, 0, 0, 0, time.UTC)
+ next := addMonthsClamped(sub, 1)
+ require.Equal(t, time.Month(2), next.Month())
+ require.Equal(t, 28, next.Day()) // Feb 28 (2026 is not a leap year)
+}
+
+func TestAddMonthsClamped_Jan31ToFebLeapYear(t *testing.T) {
+ sub := time.Date(2028, 1, 31, 0, 0, 0, 0, time.UTC)
+ next := addMonthsClamped(sub, 1)
+ require.Equal(t, time.Month(2), next.Month())
+ require.Equal(t, 29, next.Day()) // Feb 29 (2028 is a leap year)
+}
+
+func TestAddMonthsClamped_Mar31ToApr(t *testing.T) {
+ sub := time.Date(2026, 3, 31, 0, 0, 0, 0, time.UTC)
+ next := addMonthsClamped(sub, 1)
+ require.Equal(t, time.Month(4), next.Month())
+ require.Equal(t, 30, next.Day()) // Apr has 30 days
+}
+
+func TestAddMonthsClamped_NormalDay(t *testing.T) {
+ sub := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC)
+ next := addMonthsClamped(sub, 1)
+ require.Equal(t, time.Month(2), next.Month())
+ require.Equal(t, 15, next.Day()) // no clamping needed
+}
+
+// --- Redis key ---
+
func TestQuotaRedisKey_Format(t *testing.T) {
- key := quotaRedisKey("brave", QuotaRefreshDaily)
- require.Contains(t, key, "websearch:quota:brave:")
+ key := quotaRedisKey("brave")
+ require.Equal(t, "websearch:quota:brave", key)
}
// --- isProviderAvailable ---
@@ -173,9 +216,7 @@ func TestIsProviderAvailable_Valid(t *testing.T) {
func TestResolveProxyID_AccountProxyOverrides(t *testing.T) {
cfg := ProviderConfig{ProxyID: 42}
- // account proxy present → return 0 (account proxy has no config-level ID)
require.Equal(t, int64(0), resolveProxyID(cfg, "http://account-proxy:8080"))
- // no account proxy → return provider's proxy ID
require.Equal(t, int64(42), resolveProxyID(cfg, ""))
}
@@ -186,28 +227,23 @@ func TestIsProxyError_Nil(t *testing.T) {
}
func TestIsProxyError_ConnectionRefused(t *testing.T) {
- err := fmt.Errorf("dial tcp: connection refused")
- require.True(t, isProxyError(err))
+ require.True(t, isProxyError(fmt.Errorf("dial tcp: connection refused")))
}
func TestIsProxyError_Timeout(t *testing.T) {
- err := fmt.Errorf("i/o timeout while connecting to proxy")
- require.True(t, isProxyError(err))
+ require.True(t, isProxyError(fmt.Errorf("i/o timeout while connecting to proxy")))
}
func TestIsProxyError_SOCKS(t *testing.T) {
- err := fmt.Errorf("socks connect failed")
- require.True(t, isProxyError(err))
+ require.True(t, isProxyError(fmt.Errorf("socks connect failed")))
}
func TestIsProxyError_TLSHandshake(t *testing.T) {
- err := fmt.Errorf("tls handshake timeout")
- require.True(t, isProxyError(err))
+ require.True(t, isProxyError(fmt.Errorf("tls handshake timeout")))
}
func TestIsProxyError_APIError_NotProxy(t *testing.T) {
- err := fmt.Errorf("API rate limit exceeded")
- require.False(t, isProxyError(err))
+ require.False(t, isProxyError(fmt.Errorf("API rate limit exceeded")))
}
// --- isProxyAvailable (nil Redis) ---
@@ -225,14 +261,13 @@ func TestIsProxyAvailable_ZeroID(t *testing.T) {
// --- selectByQuotaWeight ---
func TestSelectByQuotaWeight_NoQuotaLast(t *testing.T) {
- m := NewManager(nil, nil) // nil Redis → GetUsage returns 0
+ m := NewManager(nil, nil)
candidates := []ProviderConfig{
- {Type: "brave", APIKey: "k1", QuotaLimit: 0}, // no limit → weight 0
- {Type: "tavily", APIKey: "k2", QuotaLimit: 100}, // remaining 100
+ {Type: "brave", APIKey: "k1", QuotaLimit: 0},
+ {Type: "tavily", APIKey: "k2", QuotaLimit: 100},
}
result := m.selectByQuotaWeight(context.Background(), candidates)
require.Len(t, result, 2)
- // tavily (with quota) should come first
require.Equal(t, "tavily", result[0].Type)
require.Equal(t, "brave", result[1].Type)
}
@@ -245,7 +280,6 @@ func TestSelectByQuotaWeight_AllNoQuota(t *testing.T) {
}
result := m.selectByQuotaWeight(context.Background(), candidates)
require.Len(t, result, 2)
- // both have weight 0, original order preserved
}
func TestSelectByQuotaWeight_Empty(t *testing.T) {
diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go
index a8034e98..ba45c31b 100644
--- a/backend/internal/server/http.go
+++ b/backend/internal/server/http.go
@@ -2,12 +2,14 @@
package server
import (
+ "context"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -56,6 +58,34 @@ func ProvideRouter(
}
}
+ // Wire up websearch Manager builder so it initializes on startup and rebuilds on config save.
+ settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig) {
+ if cfg == nil || !cfg.Enabled || len(cfg.Providers) == 0 {
+ service.SetWebSearchManager(nil)
+ return
+ }
+ configs := make([]websearch.ProviderConfig, 0, len(cfg.Providers))
+ for _, p := range cfg.Providers {
+ if p.APIKey == "" {
+ continue
+ }
+ pc := websearch.ProviderConfig{
+ Type: p.Type,
+ APIKey: p.APIKey,
+ QuotaLimit: p.QuotaLimit,
+ ExpiresAt: p.ExpiresAt,
+ }
+ if p.SubscribedAt != nil {
+ pc.SubscribedAt = p.SubscribedAt
+ }
+ if p.ProxyID != nil {
+ pc.ProxyID = *p.ProxyID
+ }
+ configs = append(configs, pc)
+ }
+ service.SetWebSearchManager(websearch.NewManager(configs, redisClient))
+ })
+
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
}
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 7c4e6cb7..0a7b7a8b 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -410,6 +410,7 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Web Search 模拟配置
adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig)
adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig)
+ adminSettings.POST("/web-search-emulation/test", h.Admin.Setting.TestWebSearchEmulation)
}
}
diff --git a/backend/internal/service/websearch_config.go b/backend/internal/service/websearch_config.go
index 15ec1f9d..bdfff3e4 100644
--- a/backend/internal/service/websearch_config.go
+++ b/backend/internal/service/websearch_config.go
@@ -22,15 +22,14 @@ type WebSearchEmulationConfig struct {
// WebSearchProviderConfig describes a single search provider (Brave or Tavily).
type WebSearchProviderConfig struct {
- Type string `json:"type"` // websearch.ProviderTypeBrave | Tavily
- APIKey string `json:"api_key,omitempty"` // secret — omitted in API responses
- APIKeyConfigured bool `json:"api_key_configured"` // read-only mask
- Priority int `json:"priority"` // lower = higher priority
- QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
- QuotaRefreshInterval string `json:"quota_refresh_interval"` // websearch.QuotaRefresh*
- QuotaUsed int64 `json:"quota_used,omitempty"` // read-only: current period usage
- ProxyID *int64 `json:"proxy_id"` // optional proxy association
- ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration timestamp
+ Type string `json:"type"` // websearch.ProviderTypeBrave | Tavily
+ APIKey string `json:"api_key,omitempty"` // secret — omitted in API responses
+ APIKeyConfigured bool `json:"api_key_configured"` // read-only mask
+ QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
+ SubscribedAt *int64 `json:"subscribed_at,omitempty"` // subscription start (unix seconds); quota resets monthly
+ QuotaUsed int64 `json:"quota_used,omitempty"` // read-only: current usage from Redis
+ ProxyID *int64 `json:"proxy_id"` // optional proxy association
+ ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration timestamp
}
// --- Validation ---
@@ -42,13 +41,6 @@ var validProviderTypes = map[string]bool{
websearch.ProviderTypeTavily: true,
}
-var validQuotaIntervals = map[string]bool{
- websearch.QuotaRefreshDaily: true,
- websearch.QuotaRefreshWeekly: true,
- websearch.QuotaRefreshMonthly: true,
- "": true, // defaults to monthly
-}
-
func validateWebSearchConfig(cfg *WebSearchEmulationConfig) error {
if cfg == nil {
return nil
@@ -61,9 +53,6 @@ func validateWebSearchConfig(cfg *WebSearchEmulationConfig) error {
if !validProviderTypes[p.Type] {
return fmt.Errorf("provider[%d]: invalid type %q", i, p.Type)
}
- if !validQuotaIntervals[p.QuotaRefreshInterval] {
- return fmt.Errorf("provider[%d]: invalid quota_refresh_interval %q", i, p.QuotaRefreshInterval)
- }
if p.QuotaLimit < 0 {
return fmt.Errorf("provider[%d]: quota_limit must be >= 0", i)
}
@@ -237,17 +226,55 @@ func (s *SettingService) RebuildWebSearchManager(ctx context.Context) {
slog.Info("websearch: manager rebuilt", "provider_count", len(providerConfigs))
}
-// SanitizeWebSearchConfig returns a copy with api_key fields masked for API responses.
-func SanitizeWebSearchConfig(cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig {
+// WebSearchTestResult holds the result of a search test.
+type WebSearchTestResult struct {
+ Provider string `json:"provider"`
+ Results []websearch.SearchResult `json:"results"`
+ Query string `json:"query"`
+}
+
+// TestWebSearch executes a test search using the currently configured Manager.
+// Uses Manager.TestSearch which bypasses quota tracking.
+func TestWebSearch(ctx context.Context, query string) (*WebSearchTestResult, error) {
+ mgr := getWebSearchManager()
+ if mgr == nil {
+ return nil, fmt.Errorf("web search: manager not initialized, save config first")
+ }
+ resp, providerName, err := mgr.TestSearch(ctx, websearch.SearchRequest{
+ Query: query,
+ MaxResults: webSearchDefaultMaxResults,
+ })
+ if err != nil {
+ return nil, err
+ }
+ return &WebSearchTestResult{
+ Provider: providerName,
+ Results: resp.Results,
+ Query: resp.Query,
+ }, nil
+}
+
+// SanitizeWebSearchConfig returns a copy with api_key fields masked and quota usage populated.
+func SanitizeWebSearchConfig(ctx context.Context, cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig {
if cfg == nil {
return nil
}
out := *cfg
out.Providers = make([]WebSearchProviderConfig, len(cfg.Providers))
+
+ // Load usage from the global Manager (reads from Redis)
+ mgr := getWebSearchManager()
+
for i, p := range cfg.Providers {
out.Providers[i] = p
out.Providers[i].APIKeyConfigured = p.APIKey != ""
out.Providers[i].APIKey = "" // never return the secret
+
+ // Populate quota usage from Redis
+ if mgr != nil {
+ used, _ := mgr.GetUsage(ctx, p.Type)
+ out.Providers[i].QuotaUsed = used
+ }
}
return &out
}
diff --git a/backend/internal/service/websearch_config_test.go b/backend/internal/service/websearch_config_test.go
index 1a19dd9d..4aea98b7 100644
--- a/backend/internal/service/websearch_config_test.go
+++ b/backend/internal/service/websearch_config_test.go
@@ -1,6 +1,7 @@
package service
import (
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -16,8 +17,8 @@ func TestValidateWebSearchConfig_Valid(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{
- {Type: "brave", Priority: 1, QuotaLimit: 1000, QuotaRefreshInterval: "monthly"},
- {Type: "tavily", Priority: 2, QuotaLimit: 500, QuotaRefreshInterval: "daily"},
+ {Type: "brave", QuotaLimit: 1000},
+ {Type: "tavily", QuotaLimit: 500},
},
}
require.NoError(t, validateWebSearchConfig(cfg))
@@ -39,13 +40,6 @@ func TestValidateWebSearchConfig_InvalidType(t *testing.T) {
require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid type")
}
-func TestValidateWebSearchConfig_InvalidQuotaInterval(t *testing.T) {
- cfg := &WebSearchEmulationConfig{
- Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: "hourly"}},
- }
- require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid quota_refresh_interval")
-}
-
func TestValidateWebSearchConfig_NegativeQuotaLimit(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: -1}},
@@ -56,20 +50,13 @@ func TestValidateWebSearchConfig_NegativeQuotaLimit(t *testing.T) {
func TestValidateWebSearchConfig_DuplicateType(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{
- {Type: "brave", Priority: 1},
- {Type: "brave", Priority: 2},
+ {Type: "brave"},
+ {Type: "brave"},
},
}
require.ErrorContains(t, validateWebSearchConfig(cfg), "duplicate type")
}
-func TestValidateWebSearchConfig_EmptyQuotaInterval(t *testing.T) {
- cfg := &WebSearchEmulationConfig{
- Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: ""}},
- }
- require.NoError(t, validateWebSearchConfig(cfg))
-}
-
func TestValidateWebSearchConfig_ZeroQuotaLimit(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: 0}},
@@ -99,6 +86,15 @@ func TestParseWebSearchConfigJSON_InvalidJSON(t *testing.T) {
require.Empty(t, cfg.Providers)
}
+func TestParseWebSearchConfigJSON_BackwardCompatibility(t *testing.T) {
+ // Old config with priority and quota_refresh_interval should parse without error
+ raw := `{"enabled":true,"providers":[{"type":"brave","priority":1,"quota_refresh_interval":"monthly","quota_limit":1000}]}`
+ cfg := parseWebSearchConfigJSON(raw)
+ require.True(t, cfg.Enabled)
+ require.Len(t, cfg.Providers, 1)
+ require.Equal(t, int64(1000), cfg.Providers[0].QuotaLimit)
+}
+
// --- SanitizeWebSearchConfig ---
func TestSanitizeWebSearchConfig_MaskAPIKey(t *testing.T) {
@@ -108,7 +104,7 @@ func TestSanitizeWebSearchConfig_MaskAPIKey(t *testing.T) {
{Type: "brave", APIKey: "sk-secret-xxx"},
},
}
- out := SanitizeWebSearchConfig(cfg)
+ out := SanitizeWebSearchConfig(context.Background(), cfg)
require.Equal(t, "", out.Providers[0].APIKey)
require.True(t, out.Providers[0].APIKeyConfigured)
}
@@ -117,25 +113,24 @@ func TestSanitizeWebSearchConfig_NoAPIKey(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: ""}},
}
- out := SanitizeWebSearchConfig(cfg)
+ out := SanitizeWebSearchConfig(context.Background(), cfg)
require.Equal(t, "", out.Providers[0].APIKey)
require.False(t, out.Providers[0].APIKeyConfigured)
}
func TestSanitizeWebSearchConfig_Nil(t *testing.T) {
- require.Nil(t, SanitizeWebSearchConfig(nil))
+ require.Nil(t, SanitizeWebSearchConfig(context.Background(), nil))
}
func TestSanitizeWebSearchConfig_PreservesOtherFields(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{
- {Type: "brave", APIKey: "secret", Priority: 10, QuotaLimit: 1000},
+ {Type: "brave", APIKey: "secret", QuotaLimit: 1000},
},
}
- out := SanitizeWebSearchConfig(cfg)
+ out := SanitizeWebSearchConfig(context.Background(), cfg)
require.True(t, out.Enabled)
- require.Equal(t, 10, out.Providers[0].Priority)
require.Equal(t, int64(1000), out.Providers[0].QuotaLimit)
}
@@ -143,6 +138,6 @@ func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "secret"}},
}
- _ = SanitizeWebSearchConfig(cfg)
+ _ = SanitizeWebSearchConfig(context.Background(), cfg)
require.Equal(t, "secret", cfg.Providers[0].APIKey)
}
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index c6323b00..31284289 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -497,9 +497,8 @@ export interface WebSearchProviderConfig {
type: 'brave' | 'tavily'
api_key: string
api_key_configured: boolean
- priority: number
quota_limit: number
- quota_refresh_interval: 'daily' | 'weekly' | 'monthly'
+ subscribed_at: number | null
quota_used?: number
proxy_id: number | null
expires_at: number | null
@@ -510,6 +509,12 @@ export interface WebSearchEmulationConfig {
providers: WebSearchProviderConfig[]
}
+export interface WebSearchTestResult {
+ provider: string
+ results: { url: string; title: string; snippet: string; page_age?: string }[]
+ query: string
+}
+
export async function getWebSearchEmulationConfig(): Promise {
const { data } = await apiClient.get(
'/admin/settings/web-search-emulation'
@@ -527,6 +532,16 @@ export async function updateWebSearchEmulationConfig(
return data
}
+export async function testWebSearchEmulation(
+ query: string
+): Promise {
+ const { data } = await apiClient.post(
+ '/admin/settings/web-search-emulation/test',
+ { query }
+ )
+ return data
+}
+
export const settingsAPI = {
getSettings,
updateSettings,
@@ -544,7 +559,8 @@ export const settingsAPI = {
getBetaPolicySettings,
updateBetaPolicySettings,
getWebSearchEmulationConfig,
- updateWebSearchEmulationConfig
+ updateWebSearchEmulationConfig,
+ testWebSearchEmulation
}
export default settingsAPI
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 7119fa36..8e10bf2a 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -4417,19 +4417,24 @@ export default {
apiKey: 'API Key',
apiKeyPlaceholder: 'Enter API Key',
apiKeyConfigured: 'Configured',
- priority: 'Priority',
- priorityHint: 'Lower number = higher priority',
+ showApiKey: 'Show',
+ hideApiKey: 'Hide',
+ copyApiKey: 'Copy',
+ copied: 'Copied',
quotaLimit: 'Quota Limit',
quotaLimitHint: '0 = unlimited',
- quotaRefreshInterval: 'Refresh Interval',
- quotaUsed: 'Used',
+ subscribedAt: 'Subscribed At',
+ subscribedAtHint: 'Quota resets monthly from this date',
+ quotaUsage: 'Usage',
proxy: 'Proxy',
- expiresAt: 'Expires At',
removeProvider: 'Remove',
- daily: 'Daily',
- weekly: 'Weekly',
- monthly: 'Monthly',
noProviders: 'No search providers configured',
+ test: 'Test',
+ testDefaultQuery: 'Major world events this year',
+ testing: 'Searching...',
+ testResultTitle: 'Search Results',
+ testResultProvider: 'Provider',
+ testNoResults: 'No results found',
},
site: {
title: 'Site Settings',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 6efaf657..1b82f419 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -4579,19 +4579,24 @@ export default {
apiKey: 'API Key',
apiKeyPlaceholder: '输入 API Key',
apiKeyConfigured: '已配置',
- priority: '优先级',
- priorityHint: '数值越小优先级越高',
+ showApiKey: '显示',
+ hideApiKey: '隐藏',
+ copyApiKey: '复制',
+ copied: '已复制',
quotaLimit: '配额上限',
quotaLimitHint: '0 表示无限制',
- quotaRefreshInterval: '刷新周期',
- quotaUsed: '已使用',
+ subscribedAt: '订阅时间',
+ subscribedAtHint: '配额从此日期起每月自动重置',
+ quotaUsage: '用量',
proxy: '代理',
- expiresAt: '过期时间',
removeProvider: '删除',
- daily: '每日',
- weekly: '每周',
- monthly: '每月',
noProviders: '未配置搜索服务商',
+ test: '测试',
+ testDefaultQuery: '搜索今年世界大事件',
+ testing: '搜索中...',
+ testResultTitle: '搜索结果',
+ testResultProvider: '服务商',
+ testNoResults: '无搜索结果',
},
site: {
title: '站点设置',
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index f57bfcf3..8ed77203 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -1751,61 +1751,152 @@
-
-
-
+ class="rounded-lg border border-gray-200 dark:border-dark-600">
+
+
+
+
+
+
+
+
+
+ {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit }}
+
+
+ {{ t('admin.settings.webSearchEmulation.apiKeyConfigured') }}
+
+
+
{{ t('admin.settings.webSearchEmulation.removeProvider') }}
-
+
+
+
{{ t('admin.settings.webSearchEmulation.apiKey') }}
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
{{ t('admin.settings.webSearchEmulation.priority') }}
-
-
{{ t('admin.settings.webSearchEmulation.priorityHint') }}
-
-
-
{{ t('admin.settings.webSearchEmulation.quotaLimit') }}
-
-
{{ t('admin.settings.webSearchEmulation.quotaLimitHint') }}
-
- {{ t('admin.settings.webSearchEmulation.quotaUsed') }}: {{ provider.quota_used }} / {{ provider.quota_limit || '∞' }}
-
-
-
- {{ t('admin.settings.webSearchEmulation.quotaRefreshInterval') }}
-
-
-
-
-
{{ t('admin.settings.webSearchEmulation.proxy') }}
-
+
+
+
+
{{ t('admin.settings.webSearchEmulation.quotaLimit') }}
+
+
{{ t('admin.settings.webSearchEmulation.quotaLimitHint') }}
+
+
+
{{ t('admin.settings.webSearchEmulation.subscribedAt') }}
+
+
{{ t('admin.settings.webSearchEmulation.subscribedAtHint') }}
+
+
+
+
+
+
{{ t('admin.settings.webSearchEmulation.quotaUsage') }}:
+
+
{{ provider.quota_used ?? 0 }} / {{ provider.quota_limit }}
+
+
+
+
+
{{ t('admin.settings.webSearchEmulation.proxy') }}
+
+
+
+
+
+
+
+
+ {{ wsTestLoading ? t('admin.settings.webSearchEmulation.testing') : t('admin.settings.webSearchEmulation.test') }}
+
+
+
+
+
+ {{ t('admin.settings.webSearchEmulation.testResultProvider') }}: {{ wsTestResult.provider }}
+
+
+ {{ t('admin.settings.webSearchEmulation.testNoResults') }}
+
+
+
{{ r.title }}
+
{{ r.snippet && r.snippet.length > 120 ? r.snippet.slice(0, 120) + '...' : r.snippet }}
+
+
+
@@ -2303,6 +2394,13 @@
]"
>{{ pt.label }}
+
+ {{ t('admin.settings.payment.enabledPaymentTypesHint') }}
+
+ {{ t('admin.settings.payment.findProvider') }}
+
+
+
@@ -2562,60 +2660,6 @@
-
-
-
-
- {{ t('admin.settings.balanceNotify.title') }}
-
-
- {{ t('admin.settings.balanceNotify.description') }}
-
-
-
-
- {{ t('admin.settings.balanceNotify.enabled') }}
-
-
-
-
{{ t('admin.settings.balanceNotify.threshold') }}
-
- $
-
-
-
{{ t('admin.settings.balanceNotify.thresholdHint') }}
-
-
-
-
-
-
-
-
- {{ t('admin.settings.quotaNotify.title') }}
-
-
- {{ t('admin.settings.quotaNotify.description') }}
-
-
-
-
-
{{ t('admin.settings.quotaNotify.emails') }}
-
-
-
-
-
-
-
-
- + {{ t('admin.settings.quotaNotify.addEmail') }}
-
-
-
{{ t('admin.settings.quotaNotify.emailsHint') }}
-
-
-
@@ -2674,8 +2718,9 @@ import type {
DefaultSubscriptionSetting,
WebSearchEmulationConfig,
WebSearchProviderConfig,
+ WebSearchTestResult,
} from '@/api/admin/settings'
-import type { AdminGroup } from '@/types'
+import type { AdminGroup, Proxy } from '@/types'
import type { ProviderInstance } from '@/types/payment'
import AppLayout from '@/components/layout/AppLayout.vue'
import Icon from '@/components/icons/Icon.vue'
@@ -2894,13 +2939,12 @@ const form = reactive({
// Gateway forwarding behavior
enable_fingerprint_unification: true,
enable_metadata_passthrough: false,
- enable_cch_signing: false,
- // Balance & quota notification
- balance_low_notify_enabled: false,
- balance_low_notify_threshold: 0,
- account_quota_notify_emails: [] as string[]
+ enable_cch_signing: false
})
+// Proxies for web search emulation ProxySelector
+const webSearchProxies = ref([])
+
// Web Search Emulation config (loaded/saved separately)
const DEFAULT_WEB_SEARCH_QUOTA_LIMIT = 1000
@@ -2909,26 +2953,101 @@ const webSearchConfig = reactive({
providers: [],
})
+const expandedProviders = reactive>({})
+const apiKeyVisible = reactive>({})
+const wsTestQuery = ref('')
+const wsTestLoading = ref(false)
+const wsTestResult = ref(null)
+
+function toggleProviderExpand(idx: number) {
+ expandedProviders[idx] = !expandedProviders[idx]
+}
+
+function removeWebSearchProvider(idx: number) {
+ webSearchConfig.providers.splice(idx, 1)
+ // Re-index expandedProviders and apiKeyVisible after removal
+ const newExpanded: Record = {}
+ const newVisible: Record = {}
+ for (let i = 0; i < webSearchConfig.providers.length; i++) {
+ const oldIdx = i >= idx ? i + 1 : i
+ newExpanded[i] = expandedProviders[oldIdx] ?? false
+ newVisible[i] = apiKeyVisible[oldIdx] ?? false
+ }
+ Object.keys(expandedProviders).forEach((k) => delete expandedProviders[Number(k)])
+ Object.keys(apiKeyVisible).forEach((k) => delete apiKeyVisible[Number(k)])
+ Object.assign(expandedProviders, newExpanded)
+ Object.assign(apiKeyVisible, newVisible)
+}
+
function addWebSearchProvider() {
+ const idx = webSearchConfig.providers.length
webSearchConfig.providers.push({
type: 'brave',
api_key: '',
api_key_configured: false,
- priority: webSearchConfig.providers.length + 1,
quota_limit: DEFAULT_WEB_SEARCH_QUOTA_LIMIT,
- quota_refresh_interval: 'monthly',
+ subscribed_at: null,
proxy_id: null,
expires_at: null,
} as WebSearchProviderConfig)
+ expandedProviders[idx] = true
+}
+
+function formatSubscribedAt(ts: number | null): string {
+ if (!ts) return ''
+ // Use UTC to avoid timezone drift on repeated edits
+ const d = new Date(ts * 1000)
+ const y = d.getUTCFullYear()
+ const m = String(d.getUTCMonth() + 1).padStart(2, '0')
+ const day = String(d.getUTCDate()).padStart(2, '0')
+ return `${y}-${m}-${day}`
+}
+
+function parseSubscribedAt(dateStr: string): number | null {
+ if (!dateStr) return null
+ // Parse as UTC to match formatSubscribedAt
+ return Math.floor(new Date(dateStr + 'T00:00:00Z').getTime() / 1000)
+}
+
+function quotaPercentage(provider: WebSearchProviderConfig): number {
+ if (!provider.quota_limit || provider.quota_limit <= 0) return 0
+ return ((provider.quota_used ?? 0) / provider.quota_limit) * 100
+}
+
+async function copyApiKey(idx: number) {
+ const key = webSearchConfig.providers[idx]?.api_key
+ if (!key) {
+ appStore.showError(t('admin.settings.webSearchEmulation.apiKeyPlaceholder'))
+ return
+ }
+ await navigator.clipboard.writeText(key)
+ appStore.showSuccess(t('admin.settings.webSearchEmulation.copied'))
+}
+
+async function testWebSearchProvider() {
+ wsTestLoading.value = true
+ wsTestResult.value = null
+ try {
+ const query = wsTestQuery.value.trim() || t('admin.settings.webSearchEmulation.testDefaultQuery')
+ wsTestResult.value = await adminAPI.settings.testWebSearchEmulation(query)
+ } catch (err: unknown) {
+ appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ } finally {
+ wsTestLoading.value = false
+ }
}
async function loadWebSearchConfig() {
try {
- const resp = await adminAPI.settings.getWebSearchEmulationConfig()
+ const [resp, proxiesResp] = await Promise.all([
+ adminAPI.settings.getWebSearchEmulationConfig(),
+ adminAPI.proxies.list().catch(() => ({ items: [] as Proxy[] })),
+ ])
if (resp) {
webSearchConfig.enabled = resp.enabled || false
webSearchConfig.providers = resp.providers || []
}
+ webSearchProxies.value = proxiesResp.items || []
} catch (err: unknown) {
// 404 is expected when config hasn't been created yet; show error for other failures
const status = (err as { status?: number })?.status
@@ -3030,14 +3149,6 @@ function handleRegistrationEmailSuffixWhitelistPaste(event: ClipboardEvent) {
}
}
-// Quota notify email helpers
-const addQuotaNotifyEmail = () => {
- if (!form.account_quota_notify_emails) {
- form.account_quota_notify_emails = []
- }
- form.account_quota_notify_emails.push('')
-}
-
// LinuxDo OAuth redirect URL suggestion
const linuxdoRedirectUrlSuggestion = computed(() => {
if (typeof window === 'undefined') return ''
@@ -3377,10 +3488,6 @@ async function saveSettings() {
payment_cancel_rate_limit_window: Number(form.payment_cancel_rate_limit_window) || 1,
payment_cancel_rate_limit_unit: form.payment_cancel_rate_limit_unit,
payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode,
- // Balance & quota notification
- balance_low_notify_enabled: form.balance_low_notify_enabled,
- balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0,
- account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e: string) => e.trim() !== ''),
}
const updated = await adminAPI.settings.updateSettings(payload)
From f694afbbf431122407fe7f4bfb3d7f4ee0e5654f Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 13:53:02 +0800
Subject: [PATCH 24/88] feat(notify): add percentage threshold type for balance
low notification
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Add threshold_type field (fixed/percentage) to system and user settings
- Add total_recharged field to users table, auto-incremented on balance credit
- Percentage mode: effective threshold = total_recharged × percentage / 100
- User-level threshold_type inherits from system default when not set
- Update admin settings UI with radio selector (fixed amount / percentage)
- Migration: 102_add_balance_notify_threshold_type.sql
---
backend/ent/migrate/schema.go | 2 +
backend/ent/mutation.go | 143 ++++++++++++++++-
backend/ent/runtime/runtime.go | 10 +-
backend/ent/schema/user.go | 5 +
backend/ent/user.go | 26 ++-
backend/ent/user/user.go | 20 +++
backend/ent/user/where.go | 115 ++++++++++++++
backend/ent/user_create.go | 150 ++++++++++++++++++
backend/ent/user_update.go | 88 ++++++++++
.../internal/handler/admin/setting_handler.go | 38 +++++
backend/internal/handler/dto/mappers.go | 28 ++--
backend/internal/handler/dto/settings.go | 7 +-
backend/internal/handler/dto/types.go | 8 +-
backend/internal/handler/user_handler.go | 14 +-
backend/internal/repository/api_key_repo.go | 34 ++--
backend/internal/repository/user_repo.go | 8 +-
.../service/balance_notify_service.go | 51 ++++--
backend/internal/service/domain_constants.go | 9 +-
backend/internal/service/setting_service.go | 9 ++
backend/internal/service/settings_view.go | 5 +-
backend/internal/service/user.go | 8 +-
backend/internal/service/user_service.go | 14 +-
.../102_add_balance_notify_threshold_type.sql | 4 +
frontend/src/api/admin/settings.ts | 2 +
frontend/src/i18n/locales/en.ts | 4 +
frontend/src/i18n/locales/zh.ts | 6 +-
frontend/src/views/admin/SettingsView.vue | 104 +++++++++++-
27 files changed, 838 insertions(+), 74 deletions(-)
create mode 100644 backend/migrations/102_add_balance_notify_threshold_type.sql
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 4f31883b..1fff61ba 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -1079,8 +1079,10 @@ var (
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
{Name: "balance_notify_enabled", Type: field.TypeBool, Default: true},
+ {Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"},
{Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index cdaf363a..3bca248d 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -28211,9 +28211,12 @@ type UserMutation struct {
totp_enabled *bool
totp_enabled_at *time.Time
balance_notify_enabled *bool
+ balance_notify_threshold_type *string
balance_notify_threshold *float64
addbalance_notify_threshold *float64
balance_notify_extra_emails *string
+ total_recharged *float64
+ addtotal_recharged *float64
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -28967,6 +28970,42 @@ func (m *UserMutation) ResetBalanceNotifyEnabled() {
m.balance_notify_enabled = nil
}
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (m *UserMutation) SetBalanceNotifyThresholdType(s string) {
+ m.balance_notify_threshold_type = &s
+}
+
+// BalanceNotifyThresholdType returns the value of the "balance_notify_threshold_type" field in the mutation.
+func (m *UserMutation) BalanceNotifyThresholdType() (r string, exists bool) {
+ v := m.balance_notify_threshold_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBalanceNotifyThresholdType returns the old "balance_notify_threshold_type" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldBalanceNotifyThresholdType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBalanceNotifyThresholdType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBalanceNotifyThresholdType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBalanceNotifyThresholdType: %w", err)
+ }
+ return oldValue.BalanceNotifyThresholdType, nil
+}
+
+// ResetBalanceNotifyThresholdType resets all changes to the "balance_notify_threshold_type" field.
+func (m *UserMutation) ResetBalanceNotifyThresholdType() {
+ m.balance_notify_threshold_type = nil
+}
+
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (m *UserMutation) SetBalanceNotifyThreshold(f float64) {
m.balance_notify_threshold = &f
@@ -29073,6 +29112,62 @@ func (m *UserMutation) ResetBalanceNotifyExtraEmails() {
m.balance_notify_extra_emails = nil
}
+// SetTotalRecharged sets the "total_recharged" field.
+func (m *UserMutation) SetTotalRecharged(f float64) {
+ m.total_recharged = &f
+ m.addtotal_recharged = nil
+}
+
+// TotalRecharged returns the value of the "total_recharged" field in the mutation.
+func (m *UserMutation) TotalRecharged() (r float64, exists bool) {
+ v := m.total_recharged
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotalRecharged returns the old "total_recharged" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldTotalRecharged(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotalRecharged is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotalRecharged requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotalRecharged: %w", err)
+ }
+ return oldValue.TotalRecharged, nil
+}
+
+// AddTotalRecharged adds f to the "total_recharged" field.
+func (m *UserMutation) AddTotalRecharged(f float64) {
+ if m.addtotal_recharged != nil {
+ *m.addtotal_recharged += f
+ } else {
+ m.addtotal_recharged = &f
+ }
+}
+
+// AddedTotalRecharged returns the value that was added to the "total_recharged" field in this mutation.
+func (m *UserMutation) AddedTotalRecharged() (r float64, exists bool) {
+ v := m.addtotal_recharged
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetTotalRecharged resets all changes to the "total_recharged" field.
+func (m *UserMutation) ResetTotalRecharged() {
+ m.total_recharged = nil
+ m.addtotal_recharged = nil
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -29647,7 +29742,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
- fields := make([]string, 0, 17)
+ fields := make([]string, 0, 19)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -29693,12 +29788,18 @@ func (m *UserMutation) Fields() []string {
if m.balance_notify_enabled != nil {
fields = append(fields, user.FieldBalanceNotifyEnabled)
}
+ if m.balance_notify_threshold_type != nil {
+ fields = append(fields, user.FieldBalanceNotifyThresholdType)
+ }
if m.balance_notify_threshold != nil {
fields = append(fields, user.FieldBalanceNotifyThreshold)
}
if m.balance_notify_extra_emails != nil {
fields = append(fields, user.FieldBalanceNotifyExtraEmails)
}
+ if m.total_recharged != nil {
+ fields = append(fields, user.FieldTotalRecharged)
+ }
return fields
}
@@ -29737,10 +29838,14 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.TotpEnabledAt()
case user.FieldBalanceNotifyEnabled:
return m.BalanceNotifyEnabled()
+ case user.FieldBalanceNotifyThresholdType:
+ return m.BalanceNotifyThresholdType()
case user.FieldBalanceNotifyThreshold:
return m.BalanceNotifyThreshold()
case user.FieldBalanceNotifyExtraEmails:
return m.BalanceNotifyExtraEmails()
+ case user.FieldTotalRecharged:
+ return m.TotalRecharged()
}
return nil, false
}
@@ -29780,10 +29885,14 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldTotpEnabledAt(ctx)
case user.FieldBalanceNotifyEnabled:
return m.OldBalanceNotifyEnabled(ctx)
+ case user.FieldBalanceNotifyThresholdType:
+ return m.OldBalanceNotifyThresholdType(ctx)
case user.FieldBalanceNotifyThreshold:
return m.OldBalanceNotifyThreshold(ctx)
case user.FieldBalanceNotifyExtraEmails:
return m.OldBalanceNotifyExtraEmails(ctx)
+ case user.FieldTotalRecharged:
+ return m.OldTotalRecharged(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -29898,6 +30007,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetBalanceNotifyEnabled(v)
return nil
+ case user.FieldBalanceNotifyThresholdType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBalanceNotifyThresholdType(v)
+ return nil
case user.FieldBalanceNotifyThreshold:
v, ok := value.(float64)
if !ok {
@@ -29912,6 +30028,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetBalanceNotifyExtraEmails(v)
return nil
+ case user.FieldTotalRecharged:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotalRecharged(v)
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -29929,6 +30052,9 @@ func (m *UserMutation) AddedFields() []string {
if m.addbalance_notify_threshold != nil {
fields = append(fields, user.FieldBalanceNotifyThreshold)
}
+ if m.addtotal_recharged != nil {
+ fields = append(fields, user.FieldTotalRecharged)
+ }
return fields
}
@@ -29943,6 +30069,8 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedConcurrency()
case user.FieldBalanceNotifyThreshold:
return m.AddedBalanceNotifyThreshold()
+ case user.FieldTotalRecharged:
+ return m.AddedTotalRecharged()
}
return nil, false
}
@@ -29973,6 +30101,13 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
}
m.AddBalanceNotifyThreshold(v)
return nil
+ case user.FieldTotalRecharged:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddTotalRecharged(v)
+ return nil
}
return fmt.Errorf("unknown User numeric field %s", name)
}
@@ -30072,12 +30207,18 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldBalanceNotifyEnabled:
m.ResetBalanceNotifyEnabled()
return nil
+ case user.FieldBalanceNotifyThresholdType:
+ m.ResetBalanceNotifyThresholdType()
+ return nil
case user.FieldBalanceNotifyThreshold:
m.ResetBalanceNotifyThreshold()
return nil
case user.FieldBalanceNotifyExtraEmails:
m.ResetBalanceNotifyExtraEmails()
return nil
+ case user.FieldTotalRecharged:
+ m.ResetTotalRecharged()
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index a288f5d9..951b5f99 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -1297,10 +1297,18 @@ func init() {
userDescBalanceNotifyEnabled := userFields[11].Descriptor()
// user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field.
user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool)
+ // userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field.
+ userDescBalanceNotifyThresholdType := userFields[12].Descriptor()
+ // user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field.
+ user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string)
// userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field.
- userDescBalanceNotifyExtraEmails := userFields[13].Descriptor()
+ userDescBalanceNotifyExtraEmails := userFields[14].Descriptor()
// user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field.
user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string)
+ // userDescTotalRecharged is the schema descriptor for total_recharged field.
+ userDescTotalRecharged := userFields[15].Descriptor()
+ // user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
+ user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index bdaa4509..ef52e985 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -76,6 +76,8 @@ func (User) Fields() []ent.Field {
// 余额不足通知
field.Bool("balance_notify_enabled").
Default(true),
+ field.String("balance_notify_threshold_type").
+ Default("fixed"), // "fixed" | "percentage"
field.Float("balance_notify_threshold").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Optional().
@@ -83,6 +85,9 @@ func (User) Fields() []ent.Field {
field.String("balance_notify_extra_emails").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Default("[]"),
+ field.Float("total_recharged").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
+ Default(0),
}
}
diff --git a/backend/ent/user.go b/backend/ent/user.go
index fc4ddb8f..9fa91f74 100644
--- a/backend/ent/user.go
+++ b/backend/ent/user.go
@@ -47,10 +47,14 @@ type User struct {
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
+ // BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field.
+ BalanceNotifyThresholdType string `json:"balance_notify_threshold_type,omitempty"`
// BalanceNotifyThreshold holds the value of the "balance_notify_threshold" field.
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
// BalanceNotifyExtraEmails holds the value of the "balance_notify_extra_emails" field.
BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"`
+ // TotalRecharged holds the value of the "total_recharged" field.
+ TotalRecharged float64 `json:"total_recharged,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -192,11 +196,11 @@ func (*User) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case user.FieldTotpEnabled, user.FieldBalanceNotifyEnabled:
values[i] = new(sql.NullBool)
- case user.FieldBalance, user.FieldBalanceNotifyThreshold:
+ case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged:
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
- case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyExtraEmails:
+ case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
values[i] = new(sql.NullTime)
@@ -314,6 +318,12 @@ func (_m *User) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.BalanceNotifyEnabled = value.Bool
}
+ case user.FieldBalanceNotifyThresholdType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field balance_notify_threshold_type", values[i])
+ } else if value.Valid {
+ _m.BalanceNotifyThresholdType = value.String
+ }
case user.FieldBalanceNotifyThreshold:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_threshold", values[i])
@@ -327,6 +337,12 @@ func (_m *User) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.BalanceNotifyExtraEmails = value.String
}
+ case user.FieldTotalRecharged:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field total_recharged", values[i])
+ } else if value.Valid {
+ _m.TotalRecharged = value.Float64
+ }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -469,6 +485,9 @@ func (_m *User) String() string {
builder.WriteString("balance_notify_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
builder.WriteString(", ")
+ builder.WriteString("balance_notify_threshold_type=")
+ builder.WriteString(_m.BalanceNotifyThresholdType)
+ builder.WriteString(", ")
if v := _m.BalanceNotifyThreshold; v != nil {
builder.WriteString("balance_notify_threshold=")
builder.WriteString(fmt.Sprintf("%v", *v))
@@ -476,6 +495,9 @@ func (_m *User) String() string {
builder.WriteString(", ")
builder.WriteString("balance_notify_extra_emails=")
builder.WriteString(_m.BalanceNotifyExtraEmails)
+ builder.WriteString(", ")
+ builder.WriteString("total_recharged=")
+ builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged))
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go
index aff37013..d88a3a38 100644
--- a/backend/ent/user/user.go
+++ b/backend/ent/user/user.go
@@ -45,10 +45,14 @@ const (
FieldTotpEnabledAt = "totp_enabled_at"
// FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
FieldBalanceNotifyEnabled = "balance_notify_enabled"
+ // FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database.
+ FieldBalanceNotifyThresholdType = "balance_notify_threshold_type"
// FieldBalanceNotifyThreshold holds the string denoting the balance_notify_threshold field in the database.
FieldBalanceNotifyThreshold = "balance_notify_threshold"
// FieldBalanceNotifyExtraEmails holds the string denoting the balance_notify_extra_emails field in the database.
FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails"
+ // FieldTotalRecharged holds the string denoting the total_recharged field in the database.
+ FieldTotalRecharged = "total_recharged"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -168,8 +172,10 @@ var Columns = []string{
FieldTotpEnabled,
FieldTotpEnabledAt,
FieldBalanceNotifyEnabled,
+ FieldBalanceNotifyThresholdType,
FieldBalanceNotifyThreshold,
FieldBalanceNotifyExtraEmails,
+ FieldTotalRecharged,
}
var (
@@ -228,8 +234,12 @@ var (
DefaultTotpEnabled bool
// DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
DefaultBalanceNotifyEnabled bool
+ // DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field.
+ DefaultBalanceNotifyThresholdType string
// DefaultBalanceNotifyExtraEmails holds the default value on creation for the "balance_notify_extra_emails" field.
DefaultBalanceNotifyExtraEmails string
+ // DefaultTotalRecharged holds the default value on creation for the "total_recharged" field.
+ DefaultTotalRecharged float64
)
// OrderOption defines the ordering options for the User queries.
@@ -315,6 +325,11 @@ func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
}
+// ByBalanceNotifyThresholdType orders the results by the balance_notify_threshold_type field.
+func ByBalanceNotifyThresholdType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBalanceNotifyThresholdType, opts...).ToFunc()
+}
+
// ByBalanceNotifyThreshold orders the results by the balance_notify_threshold field.
func ByBalanceNotifyThreshold(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyThreshold, opts...).ToFunc()
@@ -325,6 +340,11 @@ func ByBalanceNotifyExtraEmails(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyExtraEmails, opts...).ToFunc()
}
+// ByTotalRecharged orders the results by the total_recharged field.
+func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc()
+}
+
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go
index 11a0318f..2788aa7a 100644
--- a/backend/ent/user/where.go
+++ b/backend/ent/user/where.go
@@ -130,6 +130,11 @@ func BalanceNotifyEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
}
+// BalanceNotifyThresholdType applies equality check predicate on the "balance_notify_threshold_type" field. It's identical to BalanceNotifyThresholdTypeEQ.
+func BalanceNotifyThresholdType(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v))
+}
+
// BalanceNotifyThreshold applies equality check predicate on the "balance_notify_threshold" field. It's identical to BalanceNotifyThresholdEQ.
func BalanceNotifyThreshold(v float64) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
@@ -140,6 +145,11 @@ func BalanceNotifyExtraEmails(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v))
}
+// TotalRecharged applies equality check predicate on the "total_recharged" field. It's identical to TotalRechargedEQ.
+func TotalRecharged(v float64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -885,6 +895,71 @@ func BalanceNotifyEnabledNEQ(v bool) predicate.User {
return predicate.User(sql.FieldNEQ(FieldBalanceNotifyEnabled, v))
}
+// BalanceNotifyThresholdTypeEQ applies the EQ predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeEQ(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeNEQ applies the NEQ predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeNEQ(v string) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeIn applies the In predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldIn(FieldBalanceNotifyThresholdType, vs...))
+}
+
+// BalanceNotifyThresholdTypeNotIn applies the NotIn predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeNotIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThresholdType, vs...))
+}
+
+// BalanceNotifyThresholdTypeGT applies the GT predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeGT(v string) predicate.User {
+ return predicate.User(sql.FieldGT(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeGTE applies the GTE predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeGTE(v string) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeLT applies the LT predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeLT(v string) predicate.User {
+ return predicate.User(sql.FieldLT(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeLTE applies the LTE predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeLTE(v string) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeContains applies the Contains predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeContains(v string) predicate.User {
+ return predicate.User(sql.FieldContains(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeHasPrefix applies the HasPrefix predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeHasPrefix(v string) predicate.User {
+ return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeHasSuffix applies the HasSuffix predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeHasSuffix(v string) predicate.User {
+ return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeEqualFold applies the EqualFold predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeEqualFold(v string) predicate.User {
+ return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeContainsFold applies the ContainsFold predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeContainsFold(v string) predicate.User {
+ return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyThresholdType, v))
+}
+
// BalanceNotifyThresholdEQ applies the EQ predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdEQ(v float64) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
@@ -1000,6 +1075,46 @@ func BalanceNotifyExtraEmailsContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyExtraEmails, v))
}
+// TotalRechargedEQ applies the EQ predicate on the "total_recharged" field.
+func TotalRechargedEQ(v float64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
+}
+
+// TotalRechargedNEQ applies the NEQ predicate on the "total_recharged" field.
+func TotalRechargedNEQ(v float64) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldTotalRecharged, v))
+}
+
+// TotalRechargedIn applies the In predicate on the "total_recharged" field.
+func TotalRechargedIn(vs ...float64) predicate.User {
+ return predicate.User(sql.FieldIn(FieldTotalRecharged, vs...))
+}
+
+// TotalRechargedNotIn applies the NotIn predicate on the "total_recharged" field.
+func TotalRechargedNotIn(vs ...float64) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldTotalRecharged, vs...))
+}
+
+// TotalRechargedGT applies the GT predicate on the "total_recharged" field.
+func TotalRechargedGT(v float64) predicate.User {
+ return predicate.User(sql.FieldGT(FieldTotalRecharged, v))
+}
+
+// TotalRechargedGTE applies the GTE predicate on the "total_recharged" field.
+func TotalRechargedGTE(v float64) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldTotalRecharged, v))
+}
+
+// TotalRechargedLT applies the LT predicate on the "total_recharged" field.
+func TotalRechargedLT(v float64) predicate.User {
+ return predicate.User(sql.FieldLT(FieldTotalRecharged, v))
+}
+
+// TotalRechargedLTE applies the LTE predicate on the "total_recharged" field.
+func TotalRechargedLTE(v float64) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldTotalRecharged, v))
+}
+
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {
diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go
index 955fde72..fbc64f9c 100644
--- a/backend/ent/user_create.go
+++ b/backend/ent/user_create.go
@@ -225,6 +225,20 @@ func (_c *UserCreate) SetNillableBalanceNotifyEnabled(v *bool) *UserCreate {
return _c
}
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (_c *UserCreate) SetBalanceNotifyThresholdType(v string) *UserCreate {
+ _c.mutation.SetBalanceNotifyThresholdType(v)
+ return _c
+}
+
+// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
+func (_c *UserCreate) SetNillableBalanceNotifyThresholdType(v *string) *UserCreate {
+ if v != nil {
+ _c.SetBalanceNotifyThresholdType(*v)
+ }
+ return _c
+}
+
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (_c *UserCreate) SetBalanceNotifyThreshold(v float64) *UserCreate {
_c.mutation.SetBalanceNotifyThreshold(v)
@@ -253,6 +267,20 @@ func (_c *UserCreate) SetNillableBalanceNotifyExtraEmails(v *string) *UserCreate
return _c
}
+// SetTotalRecharged sets the "total_recharged" field.
+func (_c *UserCreate) SetTotalRecharged(v float64) *UserCreate {
+ _c.mutation.SetTotalRecharged(v)
+ return _c
+}
+
+// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
+func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate {
+ if v != nil {
+ _c.SetTotalRecharged(*v)
+ }
+ return _c
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -486,10 +514,18 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultBalanceNotifyEnabled
_c.mutation.SetBalanceNotifyEnabled(v)
}
+ if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok {
+ v := user.DefaultBalanceNotifyThresholdType
+ _c.mutation.SetBalanceNotifyThresholdType(v)
+ }
if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
v := user.DefaultBalanceNotifyExtraEmails
_c.mutation.SetBalanceNotifyExtraEmails(v)
}
+ if _, ok := _c.mutation.TotalRecharged(); !ok {
+ v := user.DefaultTotalRecharged
+ _c.mutation.SetTotalRecharged(v)
+ }
return nil
}
@@ -556,9 +592,15 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
}
+ if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok {
+ return &ValidationError{Name: "balance_notify_threshold_type", err: errors.New(`ent: missing required field "User.balance_notify_threshold_type"`)}
+ }
if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
return &ValidationError{Name: "balance_notify_extra_emails", err: errors.New(`ent: missing required field "User.balance_notify_extra_emails"`)}
}
+ if _, ok := _c.mutation.TotalRecharged(); !ok {
+ return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)}
+ }
return nil
}
@@ -646,6 +688,10 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
_node.BalanceNotifyEnabled = value
}
+ if value, ok := _c.mutation.BalanceNotifyThresholdType(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
+ _node.BalanceNotifyThresholdType = value
+ }
if value, ok := _c.mutation.BalanceNotifyThreshold(); ok {
_spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
_node.BalanceNotifyThreshold = &value
@@ -654,6 +700,10 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
_node.BalanceNotifyExtraEmails = value
}
+ if value, ok := _c.mutation.TotalRecharged(); ok {
+ _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ _node.TotalRecharged = value
+ }
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1068,6 +1118,18 @@ func (u *UserUpsert) UpdateBalanceNotifyEnabled() *UserUpsert {
return u
}
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (u *UserUpsert) SetBalanceNotifyThresholdType(v string) *UserUpsert {
+ u.Set(user.FieldBalanceNotifyThresholdType, v)
+ return u
+}
+
+// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
+func (u *UserUpsert) UpdateBalanceNotifyThresholdType() *UserUpsert {
+ u.SetExcluded(user.FieldBalanceNotifyThresholdType)
+ return u
+}
+
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (u *UserUpsert) SetBalanceNotifyThreshold(v float64) *UserUpsert {
u.Set(user.FieldBalanceNotifyThreshold, v)
@@ -1104,6 +1166,24 @@ func (u *UserUpsert) UpdateBalanceNotifyExtraEmails() *UserUpsert {
return u
}
+// SetTotalRecharged sets the "total_recharged" field.
+func (u *UserUpsert) SetTotalRecharged(v float64) *UserUpsert {
+ u.Set(user.FieldTotalRecharged, v)
+ return u
+}
+
+// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
+func (u *UserUpsert) UpdateTotalRecharged() *UserUpsert {
+ u.SetExcluded(user.FieldTotalRecharged)
+ return u
+}
+
+// AddTotalRecharged adds v to the "total_recharged" field.
+func (u *UserUpsert) AddTotalRecharged(v float64) *UserUpsert {
+ u.Add(user.FieldTotalRecharged, v)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1380,6 +1460,20 @@ func (u *UserUpsertOne) UpdateBalanceNotifyEnabled() *UserUpsertOne {
})
}
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (u *UserUpsertOne) SetBalanceNotifyThresholdType(v string) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyThresholdType(v)
+ })
+}
+
+// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateBalanceNotifyThresholdType() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyThresholdType()
+ })
+}
+
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (u *UserUpsertOne) SetBalanceNotifyThreshold(v float64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
@@ -1422,6 +1516,27 @@ func (u *UserUpsertOne) UpdateBalanceNotifyExtraEmails() *UserUpsertOne {
})
}
+// SetTotalRecharged sets the "total_recharged" field.
+func (u *UserUpsertOne) SetTotalRecharged(v float64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotalRecharged(v)
+ })
+}
+
+// AddTotalRecharged adds v to the "total_recharged" field.
+func (u *UserUpsertOne) AddTotalRecharged(v float64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.AddTotalRecharged(v)
+ })
+}
+
+// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotalRecharged()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1864,6 +1979,20 @@ func (u *UserUpsertBulk) UpdateBalanceNotifyEnabled() *UserUpsertBulk {
})
}
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (u *UserUpsertBulk) SetBalanceNotifyThresholdType(v string) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyThresholdType(v)
+ })
+}
+
+// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateBalanceNotifyThresholdType() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyThresholdType()
+ })
+}
+
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (u *UserUpsertBulk) SetBalanceNotifyThreshold(v float64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
@@ -1906,6 +2035,27 @@ func (u *UserUpsertBulk) UpdateBalanceNotifyExtraEmails() *UserUpsertBulk {
})
}
+// SetTotalRecharged sets the "total_recharged" field.
+func (u *UserUpsertBulk) SetTotalRecharged(v float64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotalRecharged(v)
+ })
+}
+
+// AddTotalRecharged adds v to the "total_recharged" field.
+func (u *UserUpsertBulk) AddTotalRecharged(v float64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.AddTotalRecharged(v)
+ })
+}
+
+// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotalRecharged()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go
index 823df0b6..6b355247 100644
--- a/backend/ent/user_update.go
+++ b/backend/ent/user_update.go
@@ -257,6 +257,20 @@ func (_u *UserUpdate) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdate {
return _u
}
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (_u *UserUpdate) SetBalanceNotifyThresholdType(v string) *UserUpdate {
+ _u.mutation.SetBalanceNotifyThresholdType(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdate {
+ if v != nil {
+ _u.SetBalanceNotifyThresholdType(*v)
+ }
+ return _u
+}
+
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (_u *UserUpdate) SetBalanceNotifyThreshold(v float64) *UserUpdate {
_u.mutation.ResetBalanceNotifyThreshold()
@@ -298,6 +312,27 @@ func (_u *UserUpdate) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdate
return _u
}
+// SetTotalRecharged sets the "total_recharged" field.
+func (_u *UserUpdate) SetTotalRecharged(v float64) *UserUpdate {
+ _u.mutation.ResetTotalRecharged()
+ _u.mutation.SetTotalRecharged(v)
+ return _u
+}
+
+// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableTotalRecharged(v *float64) *UserUpdate {
+ if v != nil {
+ _u.SetTotalRecharged(*v)
+ }
+ return _u
+}
+
+// AddTotalRecharged adds value to the "total_recharged" field.
+func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate {
+ _u.mutation.AddTotalRecharged(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -804,6 +839,9 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
+ if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
+ }
if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
_spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
}
@@ -816,6 +854,12 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
_spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
}
+ if value, ok := _u.mutation.TotalRecharged(); ok {
+ _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedTotalRecharged(); ok {
+ _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1518,6 +1562,20 @@ func (_u *UserUpdateOne) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdateOne
return _u
}
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (_u *UserUpdateOne) SetBalanceNotifyThresholdType(v string) *UserUpdateOne {
+ _u.mutation.SetBalanceNotifyThresholdType(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdateOne {
+ if v != nil {
+ _u.SetBalanceNotifyThresholdType(*v)
+ }
+ return _u
+}
+
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (_u *UserUpdateOne) SetBalanceNotifyThreshold(v float64) *UserUpdateOne {
_u.mutation.ResetBalanceNotifyThreshold()
@@ -1559,6 +1617,27 @@ func (_u *UserUpdateOne) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpd
return _u
}
+// SetTotalRecharged sets the "total_recharged" field.
+func (_u *UserUpdateOne) SetTotalRecharged(v float64) *UserUpdateOne {
+ _u.mutation.ResetTotalRecharged()
+ _u.mutation.SetTotalRecharged(v)
+ return _u
+}
+
+// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableTotalRecharged(v *float64) *UserUpdateOne {
+ if v != nil {
+ _u.SetTotalRecharged(*v)
+ }
+ return _u
+}
+
+// AddTotalRecharged adds value to the "total_recharged" field.
+func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne {
+ _u.mutation.AddTotalRecharged(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -2095,6 +2174,9 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
+ if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
+ }
if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
_spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
}
@@ -2107,6 +2189,12 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
_spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
}
+ if value, ok := _u.mutation.TotalRecharged(); ok {
+ _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedTotalRecharged(); ok {
+ _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index e5e024c6..3d587a21 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -176,6 +176,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
+ BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
+ BalanceLowNotifyThresholdType: settings.BalanceLowNotifyThresholdType,
+ BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
+ AccountQuotaNotifyEmails: settings.AccountQuotaNotifyEmails,
PaymentEnabled: paymentCfg.Enabled,
PaymentMinAmount: paymentCfg.MinAmount,
PaymentMaxAmount: paymentCfg.MaxAmount,
@@ -305,6 +309,12 @@ type UpdateSettingsRequest struct {
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"`
+ // Balance low notification
+ BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
+ BalanceLowNotifyThresholdType *string `json:"balance_low_notify_threshold_type"`
+ BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
+ AccountQuotaNotifyEmails *[]string `json:"account_quota_notify_emails"`
+
// Payment configuration (integrated into settings, full replace)
PaymentEnabled *bool `json:"payment_enabled"`
PaymentMinAmount *float64 `json:"payment_min_amount"`
@@ -882,6 +892,30 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableCCHSigning
}(),
+ BalanceLowNotifyEnabled: func() bool {
+ if req.BalanceLowNotifyEnabled != nil {
+ return *req.BalanceLowNotifyEnabled
+ }
+ return previousSettings.BalanceLowNotifyEnabled
+ }(),
+ BalanceLowNotifyThresholdType: func() string {
+ if req.BalanceLowNotifyThresholdType != nil {
+ return *req.BalanceLowNotifyThresholdType
+ }
+ return previousSettings.BalanceLowNotifyThresholdType
+ }(),
+ BalanceLowNotifyThreshold: func() float64 {
+ if req.BalanceLowNotifyThreshold != nil {
+ return *req.BalanceLowNotifyThreshold
+ }
+ return previousSettings.BalanceLowNotifyThreshold
+ }(),
+ AccountQuotaNotifyEmails: func() []string {
+ if req.AccountQuotaNotifyEmails != nil {
+ return *req.AccountQuotaNotifyEmails
+ }
+ return previousSettings.AccountQuotaNotifyEmails
+ }(),
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
@@ -1028,6 +1062,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning,
+ BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
+ BalanceLowNotifyThresholdType: updatedSettings.BalanceLowNotifyThresholdType,
+ BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
+ AccountQuotaNotifyEmails: updatedSettings.AccountQuotaNotifyEmails,
PaymentEnabled: updatedPaymentCfg.Enabled,
PaymentMinAmount: updatedPaymentCfg.MinAmount,
PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index a465c7fb..147072c3 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -13,19 +13,21 @@ func UserFromServiceShallow(u *service.User) *User {
return nil
}
return &User{
- ID: u.ID,
- Email: u.Email,
- Username: u.Username,
- Role: u.Role,
- Balance: u.Balance,
- Concurrency: u.Concurrency,
- Status: u.Status,
- AllowedGroups: u.AllowedGroups,
- CreatedAt: u.CreatedAt,
- UpdatedAt: u.UpdatedAt,
- BalanceNotifyEnabled: u.BalanceNotifyEnabled,
- BalanceNotifyThreshold: u.BalanceNotifyThreshold,
- BalanceNotifyExtraEmails: u.BalanceNotifyExtraEmails,
+ ID: u.ID,
+ Email: u.Email,
+ Username: u.Username,
+ Role: u.Role,
+ Balance: u.Balance,
+ Concurrency: u.Concurrency,
+ Status: u.Status,
+ AllowedGroups: u.AllowedGroups,
+ CreatedAt: u.CreatedAt,
+ UpdatedAt: u.UpdatedAt,
+ BalanceNotifyEnabled: u.BalanceNotifyEnabled,
+ BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
+ BalanceNotifyThreshold: u.BalanceNotifyThreshold,
+ BalanceNotifyExtraEmails: u.BalanceNotifyExtraEmails,
+ TotalRecharged: u.TotalRecharged,
}
}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index e29f72da..8da7c6f2 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -150,9 +150,10 @@ type SystemSettings struct {
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
// Balance low notification
- BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
- BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
- AccountQuotaNotifyEmails []string `json:"account_quota_notify_emails"`
+ BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
+ BalanceLowNotifyThresholdType string `json:"balance_low_notify_threshold_type"`
+ BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
+ AccountQuotaNotifyEmails []string `json:"account_quota_notify_emails"`
}
type DefaultSubscriptionSetting struct {
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 18522868..425d3df9 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -19,9 +19,11 @@ type User struct {
UpdatedAt time.Time `json:"updated_at"`
// 余额不足通知
- BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
- BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
- BalanceNotifyExtraEmails []string `json:"balance_notify_extra_emails"`
+ BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
+ BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
+ BalanceNotifyExtraEmails []string `json:"balance_notify_extra_emails"`
+ TotalRecharged float64 `json:"total_recharged"`
APIKeys []APIKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 4fb72ce7..48528d55 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -33,9 +33,10 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
- Username *string `json:"username"`
- BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
- BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
+ Username *string `json:"username"`
+ BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
+ BalanceNotifyThresholdType *string `json:"balance_notify_threshold_type"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
// GetProfile handles getting user profile
@@ -100,9 +101,10 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
}
svcReq := service.UpdateProfileRequest{
- Username: req.Username,
- BalanceNotifyEnabled: req.BalanceNotifyEnabled,
- BalanceNotifyThreshold: req.BalanceNotifyThreshold,
+ Username: req.Username,
+ BalanceNotifyEnabled: req.BalanceNotifyEnabled,
+ BalanceNotifyThresholdType: req.BalanceNotifyThresholdType,
+ BalanceNotifyThreshold: req.BalanceNotifyThreshold,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index 752a5937..4ecab47a 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -641,22 +641,24 @@ func userEntityToService(u *dbent.User) *service.User {
return nil
}
out := &service.User{
- ID: u.ID,
- Email: u.Email,
- Username: u.Username,
- Notes: u.Notes,
- PasswordHash: u.PasswordHash,
- Role: u.Role,
- Balance: u.Balance,
- Concurrency: u.Concurrency,
- Status: u.Status,
- TotpSecretEncrypted: u.TotpSecretEncrypted,
- TotpEnabled: u.TotpEnabled,
- TotpEnabledAt: u.TotpEnabledAt,
- BalanceNotifyEnabled: u.BalanceNotifyEnabled,
- BalanceNotifyThreshold: u.BalanceNotifyThreshold,
- CreatedAt: u.CreatedAt,
- UpdatedAt: u.UpdatedAt,
+ ID: u.ID,
+ Email: u.Email,
+ Username: u.Username,
+ Notes: u.Notes,
+ PasswordHash: u.PasswordHash,
+ Role: u.Role,
+ Balance: u.Balance,
+ Concurrency: u.Concurrency,
+ Status: u.Status,
+ TotpSecretEncrypted: u.TotpSecretEncrypted,
+ TotpEnabled: u.TotpEnabled,
+ TotpEnabledAt: u.TotpEnabledAt,
+ BalanceNotifyEnabled: u.BalanceNotifyEnabled,
+ BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
+ BalanceNotifyThreshold: u.BalanceNotifyThreshold,
+ TotalRecharged: u.TotalRecharged,
+ CreatedAt: u.CreatedAt,
+ UpdatedAt: u.UpdatedAt,
}
// Parse extra emails JSON array
if u.BalanceNotifyExtraEmails != "" && u.BalanceNotifyExtraEmails != "[]" {
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 2c544857..63168fb1 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -148,6 +148,7 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled).
+ SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails))
if userIn.BalanceNotifyThreshold == nil {
@@ -389,7 +390,12 @@ func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client)
- n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
+ update := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount)
+ // Track cumulative recharge amount for percentage-based notifications
+ if amount > 0 {
+ update = update.AddTotalRecharged(amount)
+ }
+ n, err := update.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go
index 8dd56b8f..7fbdd254 100644
--- a/backend/internal/service/balance_notify_service.go
+++ b/backend/internal/service/balance_notify_service.go
@@ -47,30 +47,21 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u
if user == nil || s.emailService == nil || s.settingRepo == nil {
return
}
-
- // Check user-level switch
if !user.BalanceNotifyEnabled {
return
}
- // Check global switch
- globalEnabled, threshold := s.getBalanceNotifyConfig(ctx)
+ globalEnabled, globalThresholdType, globalThresholdValue := s.getBalanceNotifyConfig(ctx)
if !globalEnabled {
return
}
- // User custom threshold overrides system default
- if user.BalanceNotifyThreshold != nil {
- threshold = *user.BalanceNotifyThreshold
- }
-
+ threshold := s.resolveEffectiveThreshold(user, globalThresholdType, globalThresholdValue)
if threshold <= 0 {
return
}
newBalance := oldBalance - cost
-
- // Only notify on first crossing
if oldBalance >= threshold && newBalance < threshold {
siteName := s.getSiteName(ctx)
recipients := s.collectBalanceNotifyRecipients(user)
@@ -85,6 +76,30 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u
}
}
+// resolveEffectiveThreshold computes the actual USD threshold based on type and user settings.
+func (s *BalanceNotifyService) resolveEffectiveThreshold(user *User, globalType string, globalValue float64) float64 {
+ // User-level override takes full precedence
+ if user.BalanceNotifyThreshold != nil {
+ thresholdType := user.BalanceNotifyThresholdType
+ if thresholdType == "" {
+ thresholdType = globalType
+ }
+ return computeThreshold(thresholdType, *user.BalanceNotifyThreshold, user.TotalRecharged)
+ }
+ return computeThreshold(globalType, globalValue, user.TotalRecharged)
+}
+
+// computeThreshold converts a threshold value to USD based on type.
+func computeThreshold(thresholdType string, value, totalRecharged float64) float64 {
+ if thresholdType == ThresholdTypePercentage {
+ if totalRecharged <= 0 {
+ return 0 // no recharge history → skip percentage check
+ }
+ return totalRecharged * value / 100
+ }
+ return value // fixed USD amount
+}
+
// quotaDim describes one quota dimension for notification checking.
type quotaDim struct {
name string
@@ -139,13 +154,21 @@ func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, account
}
// getBalanceNotifyConfig reads global balance notification settings.
-func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, threshold float64) {
- keys := []string{SettingKeyBalanceLowNotifyEnabled, SettingKeyBalanceLowNotifyThreshold}
+func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, thresholdType string, threshold float64) {
+ keys := []string{
+ SettingKeyBalanceLowNotifyEnabled,
+ SettingKeyBalanceLowNotifyThresholdType,
+ SettingKeyBalanceLowNotifyThreshold,
+ }
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
- return false, 0
+ return false, ThresholdTypeFixed, 0
}
enabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
+ thresholdType = settings[SettingKeyBalanceLowNotifyThresholdType]
+ if thresholdType == "" {
+ thresholdType = ThresholdTypeFixed
+ }
if v := settings[SettingKeyBalanceLowNotifyThreshold]; v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil {
threshold = f
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index 2704e0d0..3de0e343 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -251,8 +251,13 @@ const (
SettingKeyEnableCCHSigning = "enable_cch_signing"
// Balance Low Notification
- SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
- SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD)
+ SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
+ SettingKeyBalanceLowNotifyThresholdType = "balance_low_notify_threshold_type" // "fixed" | "percentage"
+ SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD 或百分比)
+
+ // Threshold type constants
+ ThresholdTypeFixed = "fixed"
+ ThresholdTypePercentage = "percentage"
// Account Quota Notification
SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组)
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index bc4f53ce..e2491cbc 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -597,6 +597,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// Balance low notification
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
+ thresholdType := settings.BalanceLowNotifyThresholdType
+ if thresholdType == "" {
+ thresholdType = ThresholdTypeFixed
+ }
+ updates[SettingKeyBalanceLowNotifyThresholdType] = thresholdType
updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64)
accountQuotaNotifyEmailsJSON, err := json.Marshal(settings.AccountQuotaNotifyEmails)
if err != nil {
@@ -1228,6 +1233,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
// Balance low notification
result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
+ result.BalanceLowNotifyThresholdType = settings[SettingKeyBalanceLowNotifyThresholdType]
+ if result.BalanceLowNotifyThresholdType == "" {
+ result.BalanceLowNotifyThresholdType = ThresholdTypeFixed
+ }
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
result.BalanceLowNotifyThreshold = v
}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index debc2b19..b28d2247 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -108,8 +108,9 @@ type SystemSettings struct {
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
// Balance low notification
- BalanceLowNotifyEnabled bool
- BalanceLowNotifyThreshold float64
+ BalanceLowNotifyEnabled bool
+ BalanceLowNotifyThresholdType string // "fixed" (default) | "percentage"
+ BalanceLowNotifyThreshold float64
// Account quota notification
AccountQuotaNotifyEmails []string
diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go
index b4818223..4ca31adc 100644
--- a/backend/internal/service/user.go
+++ b/backend/internal/service/user.go
@@ -31,9 +31,11 @@ type User struct {
TotpEnabledAt *time.Time // TOTP 启用时间
// 余额不足通知
- BalanceNotifyEnabled bool
- BalanceNotifyThreshold *float64
- BalanceNotifyExtraEmails []string
+ BalanceNotifyEnabled bool
+ BalanceNotifyThresholdType string // "fixed" (default) | "percentage"
+ BalanceNotifyThreshold *float64
+ BalanceNotifyExtraEmails []string
+ TotalRecharged float64
APIKeys []APIKey
Subscriptions []UserSubscription
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index e6b9a210..4669cb2b 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -62,11 +62,12 @@ type UserRepository interface {
// UpdateProfileRequest 更新用户资料请求
type UpdateProfileRequest struct {
- Email *string `json:"email"`
- Username *string `json:"username"`
- Concurrency *int `json:"concurrency"`
- BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
- BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
+ Email *string `json:"email"`
+ Username *string `json:"username"`
+ Concurrency *int `json:"concurrency"`
+ BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
+ BalanceNotifyThresholdType *string `json:"balance_notify_threshold_type"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
// ChangePasswordRequest 修改密码请求
@@ -143,6 +144,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if req.BalanceNotifyEnabled != nil {
user.BalanceNotifyEnabled = *req.BalanceNotifyEnabled
}
+ if req.BalanceNotifyThresholdType != nil {
+ user.BalanceNotifyThresholdType = *req.BalanceNotifyThresholdType
+ }
if req.BalanceNotifyThreshold != nil {
if *req.BalanceNotifyThreshold <= 0 {
user.BalanceNotifyThreshold = nil // clear to system default
diff --git a/backend/migrations/102_add_balance_notify_threshold_type.sql b/backend/migrations/102_add_balance_notify_threshold_type.sql
new file mode 100644
index 00000000..7ad70552
--- /dev/null
+++ b/backend/migrations/102_add_balance_notify_threshold_type.sql
@@ -0,0 +1,4 @@
+-- Add threshold type support (fixed / percentage) to balance notification
+ALTER TABLE users ADD COLUMN IF NOT EXISTS balance_notify_threshold_type VARCHAR(10) NOT NULL DEFAULT 'fixed';
+-- Track cumulative recharge amount for percentage threshold calculation
+ALTER TABLE users ADD COLUMN IF NOT EXISTS total_recharged DECIMAL(20,8) NOT NULL DEFAULT 0;
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 31284289..ec290be5 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -137,6 +137,7 @@ export interface SystemSettings {
// Balance & quota notification
balance_low_notify_enabled: boolean
+ balance_low_notify_threshold_type: 'fixed' | 'percentage'
balance_low_notify_threshold: number
account_quota_notify_emails: string[]
}
@@ -240,6 +241,7 @@ export interface UpdateSettingsRequest {
payment_cancel_rate_limit_window_mode?: string
// Balance & quota notification
balance_low_notify_enabled?: boolean
+ balance_low_notify_threshold_type?: 'fixed' | 'percentage'
balance_low_notify_threshold?: number
account_quota_notify_emails?: string[]
}
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 8e10bf2a..880a81ee 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -4633,8 +4633,12 @@ export default {
title: 'Balance Low Notification',
description: 'Send email notification when user balance falls below threshold',
enabled: 'Enable Balance Low Notification',
+ thresholdType: 'Threshold Type',
+ typeFixed: 'Fixed Amount',
+ typePercentage: 'Percentage of Recharged',
threshold: 'Default Threshold',
thresholdHint: 'Used when user has not set a custom value',
+ percentageHint: 'Notify when balance falls below this percentage of total recharged amount',
thresholdPlaceholder: 'Enter amount',
},
quotaNotify: {
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 1b82f419..41d94e06 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -4797,8 +4797,12 @@ export default {
title: '余额不足提醒',
description: '当用户余额低于阈值时发送邮件提醒',
enabled: '启用余额不足提醒',
- threshold: '默认提醒阈值',
+ thresholdType: '阈值类型',
+ typeFixed: '固定金额',
+ typePercentage: '充值百分比',
+ threshold: '提醒阈值',
thresholdHint: '用户未自定义时使用此值',
+ percentageHint: '当余额低于累计充值额的此百分比时提醒',
thresholdPlaceholder: '输入金额',
},
quotaNotify: {
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 8ed77203..af84b67d 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -2660,6 +2660,90 @@
+
+
+
+
+ {{ t('admin.settings.balanceNotify.title') }}
+
+
+ {{ t('admin.settings.balanceNotify.description') }}
+
+
+
+
+ {{ t('admin.settings.balanceNotify.enabled') }}
+
+
+
+
+
+
+
+
{{ t('admin.settings.balanceNotify.threshold') }}
+
+
+ {{ form.balance_low_notify_threshold_type === 'percentage' ? '%' : '$' }}
+
+
+
+
+ {{ form.balance_low_notify_threshold_type === 'percentage'
+ ? t('admin.settings.balanceNotify.percentageHint')
+ : t('admin.settings.balanceNotify.thresholdHint') }}
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.quotaNotify.title') }}
+
+
+ {{ t('admin.settings.quotaNotify.description') }}
+
+
+
+
+
{{ t('admin.settings.quotaNotify.emails') }}
+
+
+
+
+
+
+
+
+ + {{ t('admin.settings.quotaNotify.addEmail') }}
+
+
+
{{ t('admin.settings.quotaNotify.emailsHint') }}
+
+
+
@@ -2939,7 +3023,12 @@ const form = reactive({
// Gateway forwarding behavior
enable_fingerprint_unification: true,
enable_metadata_passthrough: false,
- enable_cch_signing: false
+ enable_cch_signing: false,
+ // Balance & quota notification
+ balance_low_notify_enabled: false,
+ balance_low_notify_threshold_type: 'fixed' as 'fixed' | 'percentage',
+ balance_low_notify_threshold: 0,
+ account_quota_notify_emails: [] as string[]
})
// Proxies for web search emulation ProxySelector
@@ -3149,6 +3238,14 @@ function handleRegistrationEmailSuffixWhitelistPaste(event: ClipboardEvent) {
}
}
+// Quota notify email helpers
+const addQuotaNotifyEmail = () => {
+ if (!form.account_quota_notify_emails) {
+ form.account_quota_notify_emails = []
+ }
+ form.account_quota_notify_emails.push('')
+}
+
// LinuxDo OAuth redirect URL suggestion
const linuxdoRedirectUrlSuggestion = computed(() => {
if (typeof window === 'undefined') return ''
@@ -3488,6 +3585,11 @@ async function saveSettings() {
payment_cancel_rate_limit_window: Number(form.payment_cancel_rate_limit_window) || 1,
payment_cancel_rate_limit_unit: form.payment_cancel_rate_limit_unit,
payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode,
+ // Balance & quota notification
+ balance_low_notify_enabled: form.balance_low_notify_enabled,
+ balance_low_notify_threshold_type: form.balance_low_notify_threshold_type,
+ balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0,
+ account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e: string) => e.trim() !== ''),
}
const updated = await adminAPI.settings.updateSettings(payload)
From 9e33d0c4c079062f1d88fa7a2accd0bb2cb88b37 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 14:43:12 +0800
Subject: [PATCH 25/88] fix: address audit findings for websearch and balance
notification
- Fix GetByKeyForAuth not selecting balance notify fields (notifications
never triggered in gateway path)
- Fix provider-level ProxyURL never resolved: inject ProxyRepository into
SettingService, resolve proxy URLs when building Manager
- Fix admin manual balance adjustment not updating total_recharged
- Add threshold_type input validation (reject invalid values)
- Fix user threshold_type inheritance: custom threshold defaults to "fixed"
instead of inheriting global type (prevents $5 being treated as 5%)
- Add try-catch for clipboard.writeText (fails on non-HTTPS)
- Add SetTotalRecharged to user Update for admin balance operations
---
backend/cmd/server/wire_gen.go | 4 +-
backend/internal/repository/api_key_repo.go | 5 ++
backend/internal/repository/user_repo.go | 3 +-
backend/internal/server/http.go | 5 +-
backend/internal/service/admin_service.go | 6 ++
.../service/balance_notify_service.go | 4 +-
backend/internal/service/setting_service.go | 31 ++++++--
backend/internal/service/user_service.go | 4 +-
backend/internal/service/websearch_config.go | 70 ++++++++++++-------
backend/internal/service/wire.go | 5 +-
frontend/src/views/admin/SettingsView.vue | 8 ++-
11 files changed, 102 insertions(+), 43 deletions(-)
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 8c47b2bd..24ee02fd 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -50,7 +50,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
settingRepository := repository.NewSettingRepository(client)
groupRepository := repository.NewGroupRepository(client, db)
- settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
+ proxyRepository := repository.NewProxyRepository(client, db)
+ settingService := service.ProvideSettingService(settingRepository, groupRepository, proxyRepository, configConfig)
emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier()
@@ -100,7 +101,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
schedulerCache := repository.NewSchedulerCache(redisClient)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
- proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
privacyClientFactory := providePrivacyClientFactory()
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index 4ecab47a..11eac7a8 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -143,6 +143,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user.FieldRole,
user.FieldBalance,
user.FieldConcurrency,
+ user.FieldBalanceNotifyEnabled,
+ user.FieldBalanceNotifyThresholdType,
+ user.FieldBalanceNotifyThreshold,
+ user.FieldBalanceNotifyExtraEmails,
+ user.FieldTotalRecharged,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 63168fb1..1792ef8d 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -150,7 +150,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled).
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
- SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails))
+ SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
+ SetTotalRecharged(userIn.TotalRecharged)
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}
diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go
index ba45c31b..5165b059 100644
--- a/backend/internal/server/http.go
+++ b/backend/internal/server/http.go
@@ -59,7 +59,7 @@ func ProvideRouter(
}
// Wire up websearch Manager builder so it initializes on startup and rebuilds on config save.
- settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig) {
+ settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig, proxyURLs map[int64]string) {
if cfg == nil || !cfg.Enabled || len(cfg.Providers) == 0 {
service.SetWebSearchManager(nil)
return
@@ -80,6 +80,9 @@ func ProvideRouter(
}
if p.ProxyID != nil {
pc.ProxyID = *p.ProxyID
+ if u, ok := proxyURLs[*p.ProxyID]; ok {
+ pc.ProxyURL = u
+ }
}
configs = append(configs, pc)
}
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 97b42c24..a4e22b22 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -709,6 +709,12 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance)
}
+ // Track cumulative recharge for percentage-based balance notifications
+ balanceDelta := user.Balance - oldBalance
+ if balanceDelta > 0 {
+ user.TotalRecharged += balanceDelta
+ }
+
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go
index 7fbdd254..e1f6bd8b 100644
--- a/backend/internal/service/balance_notify_service.go
+++ b/backend/internal/service/balance_notify_service.go
@@ -77,12 +77,12 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u
}
// resolveEffectiveThreshold computes the actual USD threshold based on type and user settings.
+// When user sets a custom threshold, their type is used independently (defaults to "fixed" if unset).
func (s *BalanceNotifyService) resolveEffectiveThreshold(user *User, globalType string, globalValue float64) float64 {
- // User-level override takes full precedence
if user.BalanceNotifyThreshold != nil {
thresholdType := user.BalanceNotifyThresholdType
if thresholdType == "" {
- thresholdType = globalType
+ thresholdType = ThresholdTypeFixed // user custom value defaults to fixed, not inherited
}
return computeThreshold(thresholdType, *user.BalanceNotifyThreshold, user.TotalRecharged)
}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index e2491cbc..9b307426 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -99,13 +99,19 @@ type DefaultSubscriptionGroupReader interface {
GetByID(ctx context.Context, id int64) (*Group, error)
}
+// WebSearchManagerBuilder creates a websearch.Manager from config (injected by infra layer).
+// proxyURLs maps proxy ID to resolved URL for provider-level proxy support.
+type WebSearchManagerBuilder func(cfg *WebSearchEmulationConfig, proxyURLs map[int64]string)
+
// SettingService 系统设置服务
type SettingService struct {
- settingRepo SettingRepository
- defaultSubGroupReader DefaultSubscriptionGroupReader
- cfg *config.Config
- onUpdate func() // Callback when settings are updated (for cache invalidation)
- version string // Application version
+ settingRepo SettingRepository
+ defaultSubGroupReader DefaultSubscriptionGroupReader
+ proxyRepo ProxyRepository // for resolving websearch provider proxy URLs
+ cfg *config.Config
+ onUpdate func() // Callback when settings are updated (for cache invalidation)
+ version string // Application version
+ webSearchManagerBuilder WebSearchManagerBuilder
}
// NewSettingService 创建系统设置服务实例
@@ -121,6 +127,11 @@ func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscri
s.defaultSubGroupReader = reader
}
+// SetProxyRepository injects a proxy repo for resolving websearch provider proxy URLs.
+func (s *SettingService) SetProxyRepository(repo ProxyRepository) {
+ s.proxyRepo = repo
+}
+
// GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx)
@@ -598,7 +609,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// Balance low notification
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
thresholdType := settings.BalanceLowNotifyThresholdType
- if thresholdType == "" {
+ if thresholdType != ThresholdTypeFixed && thresholdType != ThresholdTypePercentage {
thresholdType = ThresholdTypeFixed
}
updates[SettingKeyBalanceLowNotifyThresholdType] = thresholdType
@@ -1231,6 +1242,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
+ // Web search emulation: quick enabled check from the JSON config
+ if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
+ var wsCfg WebSearchEmulationConfig
+ if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil {
+ result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
+ }
+ }
+
// Balance low notification
result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
result.BalanceLowNotifyThresholdType = settings[SettingKeyBalanceLowNotifyThresholdType]
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 4669cb2b..26021a9b 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -145,7 +145,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.BalanceNotifyEnabled = *req.BalanceNotifyEnabled
}
if req.BalanceNotifyThresholdType != nil {
- user.BalanceNotifyThresholdType = *req.BalanceNotifyThresholdType
+ if *req.BalanceNotifyThresholdType == ThresholdTypeFixed || *req.BalanceNotifyThresholdType == ThresholdTypePercentage {
+ user.BalanceNotifyThresholdType = *req.BalanceNotifyThresholdType
+ }
}
if req.BalanceNotifyThreshold != nil {
if *req.BalanceNotifyThreshold <= 0 {
diff --git a/backend/internal/service/websearch_config.go b/backend/internal/service/websearch_config.go
index bdfff3e4..346faf1f 100644
--- a/backend/internal/service/websearch_config.go
+++ b/backend/internal/service/websearch_config.go
@@ -10,7 +10,6 @@ import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
- "github.com/redis/go-redis/v9"
"golang.org/x/sync/singleflight"
)
@@ -85,8 +84,7 @@ const (
// GetWebSearchEmulationConfig returns the configuration with in-process cache + singleflight.
func (s *SettingService) GetWebSearchEmulationConfig(ctx context.Context) (*WebSearchEmulationConfig, error) {
if cached := webSearchEmulationCache.Load(); cached != nil {
- c := cached.(*cachedWebSearchEmulationConfig)
- if time.Now().UnixNano() < c.expiresAt {
+ if c, ok := cached.(*cachedWebSearchEmulationConfig); ok && time.Now().UnixNano() < c.expiresAt {
return c.config, nil
}
}
@@ -96,7 +94,10 @@ func (s *SettingService) GetWebSearchEmulationConfig(ctx context.Context) (*WebS
if err != nil {
return &WebSearchEmulationConfig{}, err
}
- return result.(*WebSearchEmulationConfig), nil
+ if cfg, ok := result.(*WebSearchEmulationConfig); ok {
+ return cfg, nil
+ }
+ return &WebSearchEmulationConfig{}, nil
}
func (s *SettingService) loadWebSearchConfigFromDB() (*WebSearchEmulationConfig, error) {
@@ -154,7 +155,7 @@ func (s *SettingService) SaveWebSearchEmulationConfig(ctx context.Context, cfg *
})
// Hot-reload: rebuild the global Manager with new config
- s.RebuildWebSearchManager(ctx)
+ s.rebuildWebSearchManager(ctx)
return nil
}
@@ -196,34 +197,51 @@ func (s *SettingService) IsWebSearchEmulationEnabled(ctx context.Context) bool {
return cfg.Enabled && len(cfg.Providers) > 0
}
-// SetWebSearchRedisClient injects the Redis client used for quota tracking.
-// Call after construction, before first use. Triggers initial Manager build.
-func (s *SettingService) SetWebSearchRedisClient(ctx context.Context, redisClient *redis.Client) {
- s.webSearchRedis = redisClient
- s.RebuildWebSearchManager(ctx)
+// SetWebSearchManagerBuilder injects a callback that creates and wires a websearch.Manager.
+// The infra layer (main/wire) provides this builder, keeping redis out of the service layer.
+// Triggers initial build.
+func (s *SettingService) SetWebSearchManagerBuilder(ctx context.Context, builder WebSearchManagerBuilder) {
+ s.webSearchManagerBuilder = builder
+ s.rebuildWebSearchManager(ctx)
}
-// RebuildWebSearchManager reads the current config and (re)creates the global websearch.Manager.
-// Called on startup and after SaveWebSearchEmulationConfig.
-func (s *SettingService) RebuildWebSearchManager(ctx context.Context) {
+// rebuildWebSearchManager reads the current config, resolves proxy URLs, and invokes the builder.
+func (s *SettingService) rebuildWebSearchManager(ctx context.Context) {
+ if s.webSearchManagerBuilder == nil {
+ return
+ }
cfg, err := s.GetWebSearchEmulationConfig(ctx)
- if err != nil || !cfg.Enabled || len(cfg.Providers) == 0 {
+ if err != nil {
SetWebSearchManager(nil)
return
}
- providerConfigs := make([]websearch.ProviderConfig, 0, len(cfg.Providers))
- for _, p := range cfg.Providers {
- providerConfigs = append(providerConfigs, websearch.ProviderConfig{
- Type: p.Type,
- APIKey: p.APIKey,
- Priority: p.Priority,
- QuotaLimit: p.QuotaLimit,
- QuotaRefreshInterval: p.QuotaRefreshInterval,
- ExpiresAt: p.ExpiresAt,
- })
+ proxyURLs := s.resolveProviderProxyURLs(ctx, cfg)
+ s.webSearchManagerBuilder(cfg, proxyURLs)
+}
+
+// resolveProviderProxyURLs collects proxy IDs from providers and resolves them to URLs.
+func (s *SettingService) resolveProviderProxyURLs(ctx context.Context, cfg *WebSearchEmulationConfig) map[int64]string {
+ if cfg == nil || s.proxyRepo == nil {
+ return nil
}
- SetWebSearchManager(websearch.NewManager(providerConfigs, s.webSearchRedis))
- slog.Info("websearch: manager rebuilt", "provider_count", len(providerConfigs))
+ var ids []int64
+ for _, p := range cfg.Providers {
+ if p.ProxyID != nil && *p.ProxyID > 0 {
+ ids = append(ids, *p.ProxyID)
+ }
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ proxies, err := s.proxyRepo.ListByIDs(ctx, ids)
+ if err != nil {
+ return nil
+ }
+ result := make(map[int64]string, len(proxies))
+ for _, px := range proxies {
+ result[px.ID] = px.URL()
+ }
+ return result
}
// WebSearchTestResult holds the result of a search test.
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index 2827f135..b4e33039 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -373,10 +373,11 @@ func ProvideBackupService(
return svc
}
-// ProvideSettingService wires SettingService with group reader for default subscription validation.
-func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService {
+// ProvideSettingService wires SettingService with group reader and proxy repo.
+func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, proxyRepo ProxyRepository, cfg *config.Config) *SettingService {
svc := NewSettingService(settingRepo, cfg)
svc.SetDefaultSubscriptionGroupReader(groupRepo)
+ svc.SetProxyRepository(proxyRepo)
return svc
}
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index af84b67d..50b532fb 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -3109,8 +3109,12 @@ async function copyApiKey(idx: number) {
appStore.showError(t('admin.settings.webSearchEmulation.apiKeyPlaceholder'))
return
}
- await navigator.clipboard.writeText(key)
- appStore.showSuccess(t('admin.settings.webSearchEmulation.copied'))
+ try {
+ await navigator.clipboard.writeText(key)
+ appStore.showSuccess(t('admin.settings.webSearchEmulation.copied'))
+ } catch {
+ appStore.showError(t('common.error'))
+ }
}
async function testWebSearchProvider() {
From cef22c70abb113eca5ff01262a5e5f31a63ce0fe Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 15:01:10 +0800
Subject: [PATCH 26/88] fix(notify): remove percentage threshold from balance
notification
Balance low notification only supports a fixed USD amount threshold.
Percentage threshold is a quota concept, not applicable to balance.
Reverted threshold_type from admin settings, user profile, and all
backend/frontend layers. DB fields (balance_notify_threshold_type,
total_recharged) retained for potential future quota use.
---
.../internal/handler/admin/setting_handler.go | 15 ++----
backend/internal/handler/dto/settings.go | 7 ++-
backend/internal/handler/user_handler.go | 14 +++---
.../service/balance_notify_service.go | 46 ++++---------------
backend/internal/service/domain_constants.go | 9 +---
backend/internal/service/setting_service.go | 9 ----
backend/internal/service/settings_view.go | 5 +-
backend/internal/service/user_service.go | 16 ++-----
frontend/src/api/admin/settings.ts | 2 -
frontend/src/i18n/locales/en.ts | 4 --
frontend/src/i18n/locales/zh.ts | 6 +--
frontend/src/views/admin/SettingsView.vue | 44 +++---------------
12 files changed, 37 insertions(+), 140 deletions(-)
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 3d587a21..49e7aeed 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -177,7 +177,6 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableCCHSigning: settings.EnableCCHSigning,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
- BalanceLowNotifyThresholdType: settings.BalanceLowNotifyThresholdType,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
AccountQuotaNotifyEmails: settings.AccountQuotaNotifyEmails,
PaymentEnabled: paymentCfg.Enabled,
@@ -310,10 +309,9 @@ type UpdateSettingsRequest struct {
EnableCCHSigning *bool `json:"enable_cch_signing"`
// Balance low notification
- BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
- BalanceLowNotifyThresholdType *string `json:"balance_low_notify_threshold_type"`
- BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
- AccountQuotaNotifyEmails *[]string `json:"account_quota_notify_emails"`
+ BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
+ BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
+ AccountQuotaNotifyEmails *[]string `json:"account_quota_notify_emails"`
// Payment configuration (integrated into settings, full replace)
PaymentEnabled *bool `json:"payment_enabled"`
@@ -898,12 +896,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.BalanceLowNotifyEnabled
}(),
- BalanceLowNotifyThresholdType: func() string {
- if req.BalanceLowNotifyThresholdType != nil {
- return *req.BalanceLowNotifyThresholdType
- }
- return previousSettings.BalanceLowNotifyThresholdType
- }(),
BalanceLowNotifyThreshold: func() float64 {
if req.BalanceLowNotifyThreshold != nil {
return *req.BalanceLowNotifyThreshold
@@ -1063,7 +1055,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning,
BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
- BalanceLowNotifyThresholdType: updatedSettings.BalanceLowNotifyThresholdType,
BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
AccountQuotaNotifyEmails: updatedSettings.AccountQuotaNotifyEmails,
PaymentEnabled: updatedPaymentCfg.Enabled,
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 8da7c6f2..e29f72da 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -150,10 +150,9 @@ type SystemSettings struct {
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
// Balance low notification
- BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
- BalanceLowNotifyThresholdType string `json:"balance_low_notify_threshold_type"`
- BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
- AccountQuotaNotifyEmails []string `json:"account_quota_notify_emails"`
+ BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
+ BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
+ AccountQuotaNotifyEmails []string `json:"account_quota_notify_emails"`
}
type DefaultSubscriptionSetting struct {
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 48528d55..4fb72ce7 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -33,10 +33,9 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
- Username *string `json:"username"`
- BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
- BalanceNotifyThresholdType *string `json:"balance_notify_threshold_type"`
- BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
+ Username *string `json:"username"`
+ BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
// GetProfile handles getting user profile
@@ -101,10 +100,9 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
}
svcReq := service.UpdateProfileRequest{
- Username: req.Username,
- BalanceNotifyEnabled: req.BalanceNotifyEnabled,
- BalanceNotifyThresholdType: req.BalanceNotifyThresholdType,
- BalanceNotifyThreshold: req.BalanceNotifyThreshold,
+ Username: req.Username,
+ BalanceNotifyEnabled: req.BalanceNotifyEnabled,
+ BalanceNotifyThreshold: req.BalanceNotifyThreshold,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go
index e1f6bd8b..0d7e4c09 100644
--- a/backend/internal/service/balance_notify_service.go
+++ b/backend/internal/service/balance_notify_service.go
@@ -51,12 +51,16 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u
return
}
- globalEnabled, globalThresholdType, globalThresholdValue := s.getBalanceNotifyConfig(ctx)
+ globalEnabled, globalThreshold := s.getBalanceNotifyConfig(ctx)
if !globalEnabled {
return
}
- threshold := s.resolveEffectiveThreshold(user, globalThresholdType, globalThresholdValue)
+ // User custom threshold overrides system default
+ threshold := globalThreshold
+ if user.BalanceNotifyThreshold != nil {
+ threshold = *user.BalanceNotifyThreshold
+ }
if threshold <= 0 {
return
}
@@ -76,30 +80,6 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u
}
}
-// resolveEffectiveThreshold computes the actual USD threshold based on type and user settings.
-// When user sets a custom threshold, their type is used independently (defaults to "fixed" if unset).
-func (s *BalanceNotifyService) resolveEffectiveThreshold(user *User, globalType string, globalValue float64) float64 {
- if user.BalanceNotifyThreshold != nil {
- thresholdType := user.BalanceNotifyThresholdType
- if thresholdType == "" {
- thresholdType = ThresholdTypeFixed // user custom value defaults to fixed, not inherited
- }
- return computeThreshold(thresholdType, *user.BalanceNotifyThreshold, user.TotalRecharged)
- }
- return computeThreshold(globalType, globalValue, user.TotalRecharged)
-}
-
-// computeThreshold converts a threshold value to USD based on type.
-func computeThreshold(thresholdType string, value, totalRecharged float64) float64 {
- if thresholdType == ThresholdTypePercentage {
- if totalRecharged <= 0 {
- return 0 // no recharge history → skip percentage check
- }
- return totalRecharged * value / 100
- }
- return value // fixed USD amount
-}
-
// quotaDim describes one quota dimension for notification checking.
type quotaDim struct {
name string
@@ -154,21 +134,13 @@ func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, account
}
// getBalanceNotifyConfig reads global balance notification settings.
-func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, thresholdType string, threshold float64) {
- keys := []string{
- SettingKeyBalanceLowNotifyEnabled,
- SettingKeyBalanceLowNotifyThresholdType,
- SettingKeyBalanceLowNotifyThreshold,
- }
+func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, threshold float64) {
+ keys := []string{SettingKeyBalanceLowNotifyEnabled, SettingKeyBalanceLowNotifyThreshold}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
- return false, ThresholdTypeFixed, 0
+ return false, 0
}
enabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
- thresholdType = settings[SettingKeyBalanceLowNotifyThresholdType]
- if thresholdType == "" {
- thresholdType = ThresholdTypeFixed
- }
if v := settings[SettingKeyBalanceLowNotifyThreshold]; v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil {
threshold = f
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index 3de0e343..2704e0d0 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -251,13 +251,8 @@ const (
SettingKeyEnableCCHSigning = "enable_cch_signing"
// Balance Low Notification
- SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
- SettingKeyBalanceLowNotifyThresholdType = "balance_low_notify_threshold_type" // "fixed" | "percentage"
- SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD 或百分比)
-
- // Threshold type constants
- ThresholdTypeFixed = "fixed"
- ThresholdTypePercentage = "percentage"
+ SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
+ SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD)
// Account Quota Notification
SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组)
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 9b307426..f0cf750a 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -608,11 +608,6 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// Balance low notification
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
- thresholdType := settings.BalanceLowNotifyThresholdType
- if thresholdType != ThresholdTypeFixed && thresholdType != ThresholdTypePercentage {
- thresholdType = ThresholdTypeFixed
- }
- updates[SettingKeyBalanceLowNotifyThresholdType] = thresholdType
updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64)
accountQuotaNotifyEmailsJSON, err := json.Marshal(settings.AccountQuotaNotifyEmails)
if err != nil {
@@ -1252,10 +1247,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
// Balance low notification
result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
- result.BalanceLowNotifyThresholdType = settings[SettingKeyBalanceLowNotifyThresholdType]
- if result.BalanceLowNotifyThresholdType == "" {
- result.BalanceLowNotifyThresholdType = ThresholdTypeFixed
- }
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
result.BalanceLowNotifyThreshold = v
}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index b28d2247..debc2b19 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -108,9 +108,8 @@ type SystemSettings struct {
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
// Balance low notification
- BalanceLowNotifyEnabled bool
- BalanceLowNotifyThresholdType string // "fixed" (default) | "percentage"
- BalanceLowNotifyThreshold float64
+ BalanceLowNotifyEnabled bool
+ BalanceLowNotifyThreshold float64
// Account quota notification
AccountQuotaNotifyEmails []string
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 26021a9b..e6b9a210 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -62,12 +62,11 @@ type UserRepository interface {
// UpdateProfileRequest 更新用户资料请求
type UpdateProfileRequest struct {
- Email *string `json:"email"`
- Username *string `json:"username"`
- Concurrency *int `json:"concurrency"`
- BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
- BalanceNotifyThresholdType *string `json:"balance_notify_threshold_type"`
- BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
+ Email *string `json:"email"`
+ Username *string `json:"username"`
+ Concurrency *int `json:"concurrency"`
+ BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
// ChangePasswordRequest 修改密码请求
@@ -144,11 +143,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if req.BalanceNotifyEnabled != nil {
user.BalanceNotifyEnabled = *req.BalanceNotifyEnabled
}
- if req.BalanceNotifyThresholdType != nil {
- if *req.BalanceNotifyThresholdType == ThresholdTypeFixed || *req.BalanceNotifyThresholdType == ThresholdTypePercentage {
- user.BalanceNotifyThresholdType = *req.BalanceNotifyThresholdType
- }
- }
if req.BalanceNotifyThreshold != nil {
if *req.BalanceNotifyThreshold <= 0 {
user.BalanceNotifyThreshold = nil // clear to system default
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index ec290be5..31284289 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -137,7 +137,6 @@ export interface SystemSettings {
// Balance & quota notification
balance_low_notify_enabled: boolean
- balance_low_notify_threshold_type: 'fixed' | 'percentage'
balance_low_notify_threshold: number
account_quota_notify_emails: string[]
}
@@ -241,7 +240,6 @@ export interface UpdateSettingsRequest {
payment_cancel_rate_limit_window_mode?: string
// Balance & quota notification
balance_low_notify_enabled?: boolean
- balance_low_notify_threshold_type?: 'fixed' | 'percentage'
balance_low_notify_threshold?: number
account_quota_notify_emails?: string[]
}
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 880a81ee..8e10bf2a 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -4633,12 +4633,8 @@ export default {
title: 'Balance Low Notification',
description: 'Send email notification when user balance falls below threshold',
enabled: 'Enable Balance Low Notification',
- thresholdType: 'Threshold Type',
- typeFixed: 'Fixed Amount',
- typePercentage: 'Percentage of Recharged',
threshold: 'Default Threshold',
thresholdHint: 'Used when user has not set a custom value',
- percentageHint: 'Notify when balance falls below this percentage of total recharged amount',
thresholdPlaceholder: 'Enter amount',
},
quotaNotify: {
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 41d94e06..1b82f419 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -4797,12 +4797,8 @@ export default {
title: '余额不足提醒',
description: '当用户余额低于阈值时发送邮件提醒',
enabled: '启用余额不足提醒',
- thresholdType: '阈值类型',
- typeFixed: '固定金额',
- typePercentage: '充值百分比',
- threshold: '提醒阈值',
+ threshold: '默认提醒阈值',
thresholdHint: '用户未自定义时使用此值',
- percentageHint: '当余额低于累计充值额的此百分比时提醒',
thresholdPlaceholder: '输入金额',
},
quotaNotify: {
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 50b532fb..faab49fe 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -2675,43 +2675,13 @@
{{ t('admin.settings.balanceNotify.enabled') }}
-
-
-
-
-
-
{{ t('admin.settings.balanceNotify.threshold') }}
-
-
- {{ form.balance_low_notify_threshold_type === 'percentage' ? '%' : '$' }}
-
-
-
-
- {{ form.balance_low_notify_threshold_type === 'percentage'
- ? t('admin.settings.balanceNotify.percentageHint')
- : t('admin.settings.balanceNotify.thresholdHint') }}
-
+
+
{{ t('admin.settings.balanceNotify.threshold') }}
+
+ $
+
+
{{ t('admin.settings.balanceNotify.thresholdHint') }}
@@ -3026,7 +2996,6 @@ const form = reactive({
enable_cch_signing: false,
// Balance & quota notification
balance_low_notify_enabled: false,
- balance_low_notify_threshold_type: 'fixed' as 'fixed' | 'percentage',
balance_low_notify_threshold: 0,
account_quota_notify_emails: [] as string[]
})
@@ -3591,7 +3560,6 @@ async function saveSettings() {
payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode,
// Balance & quota notification
balance_low_notify_enabled: form.balance_low_notify_enabled,
- balance_low_notify_threshold_type: form.balance_low_notify_threshold_type,
balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0,
account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e: string) => e.trim() !== ''),
}
From 889b5b4f3bfd6f3f5422b65fea44419710a55c7f Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 15:59:45 +0800
Subject: [PATCH 27/88] fix(websearch): improve settings UI and hide config
when globally disabled
- API Key show/copy buttons moved inside input field (inline icons)
- Proxy selector and test button on same row to save vertical space
- Test opens a dialog modal instead of inline display
- Hide all websearch config in channels/accounts when global toggle is off
---
backend/cmd/server/VERSION | 2 +-
.../components/account/CreateAccountModal.vue | 10 +-
.../components/account/EditAccountModal.vue | 10 +-
frontend/src/views/admin/ChannelsView.vue | 86 ++++++++--
frontend/src/views/admin/SettingsView.vue | 155 ++++++++++--------
5 files changed, 177 insertions(+), 86 deletions(-)
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index e534f2aa..061bed0b 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.108.140
+0.1.110.11
diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue
index d5d8ff89..e0dbce61 100644
--- a/frontend/src/components/account/CreateAccountModal.vue
+++ b/frontend/src/components/account/CreateAccountModal.vue
@@ -2325,9 +2325,9 @@
-
+
@@ -2998,6 +2998,12 @@ const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OF
const codexCLIOnlyEnabled = ref(false)
const anthropicPassthroughEnabled = ref(false)
const webSearchEmulationEnabled = ref(false)
+const webSearchGlobalEnabled = ref(false)
+
+// Load web search global state once
+adminAPI.settings.getWebSearchEmulationConfig().then(cfg => {
+ webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0
+}).catch(() => { webSearchGlobalEnabled.value = false })
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages
const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index 54dae5c2..086575e6 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -1149,9 +1149,9 @@
-
+
@@ -1975,6 +1975,12 @@ const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OF
const codexCLIOnlyEnabled = ref(false)
const anthropicPassthroughEnabled = ref(false)
const webSearchEmulationEnabled = ref(false)
+const webSearchGlobalEnabled = ref(false)
+
+// Load web search global state once
+adminAPI.settings.getWebSearchEmulationConfig().then(cfg => {
+ webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0
+}).catch(() => { webSearchGlobalEnabled.value = false })
const editQuotaLimit = ref(null)
const editQuotaDailyLimit = ref(null)
const editQuotaWeeklyLimit = ref(null)
diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue
index a49e1694..639ace4a 100644
--- a/frontend/src/views/admin/ChannelsView.vue
+++ b/frontend/src/views/admin/ChannelsView.vue
@@ -306,6 +306,21 @@
+
+
+
+
+
+ {{ t('admin.channels.form.webSearchEmulation') }}
+
+
+ {{ t('admin.channels.form.webSearchEmulationHint') }}
+
+
+
+
+
+
@@ -560,6 +575,7 @@
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
+import { extractApiErrorMessage } from '@/utils/apiError'
import { adminAPI } from '@/api/admin'
import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest, AccountStatsPricingRule } from '@/api/admin/channels'
import type { PricingFormEntry } from '@/components/admin/channel/types'
@@ -583,6 +599,18 @@ import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
const { t } = useI18n()
const appStore = useAppStore()
+// Web Search global enabled state (loaded once on mount)
+const webSearchGlobalEnabled = ref(false)
+async function loadWebSearchGlobalState() {
+ try {
+ const cfg = await adminAPI.settings.getWebSearchEmulationConfig()
+ webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0
+ } catch (err: unknown) {
+ console.warn('Failed to load web search global state:', err)
+ webSearchGlobalEnabled.value = false
+ }
+}
+
// ── Platform Section type ──
interface PlatformSection {
platform: GroupPlatform
@@ -591,6 +619,7 @@ interface PlatformSection {
group_ids: number[]
model_mapping: Record
model_pricing: PricingFormEntry[]
+ web_search_emulation: boolean
}
// ── Table columns ──
@@ -709,7 +738,8 @@ function addPlatformSection(platform: GroupPlatform) {
collapsed: false,
group_ids: [],
model_mapping: {},
- model_pricing: []
+ model_pricing: [],
+ web_search_emulation: false,
})
}
@@ -901,10 +931,14 @@ function accountStatsRulesToAPI(): AccountStatsPricingRule[] {
}
// ── Form ↔ API conversion ──
-function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record> } {
+function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record>, features_config: Record } {
const group_ids: number[] = []
const model_pricing: ChannelModelPricing[] = []
const model_mapping: Record> = {}
+ // Preserve existing features_config fields not managed by the form
+ const featuresConfig: Record = editingChannel.value?.features_config
+ ? { ...editingChannel.value.features_config }
+ : {}
for (const section of form.platforms) {
if (!section.enabled) continue
@@ -933,7 +967,19 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[
}
}
- return { group_ids, model_pricing, model_mapping }
+ // Collect web_search_emulation (only anthropic platform supports it)
+ const wsEmulation: Record = {}
+ for (const section of form.platforms) {
+ if (!section.enabled) continue
+ if (section.web_search_emulation && section.platform === 'anthropic') {
+ wsEmulation[section.platform] = true
+ }
+ }
+ if (Object.keys(wsEmulation).length > 0) {
+ featuresConfig.web_search_emulation = wsEmulation
+ }
+
+ return { group_ids, model_pricing, model_mapping, features_config: featuresConfig }
}
function apiToForm(channel: Channel): PlatformSection[] {
@@ -977,13 +1023,19 @@ function apiToForm(channel: Channel): PlatformSection[] {
intervals: apiIntervalsToForm(p.intervals || [])
} as PricingFormEntry))
+ // Read web_search_emulation from features_config
+ const fc = channel.features_config
+ const wsEmulation = fc?.web_search_emulation as Record | undefined
+ const webSearchEnabled = wsEmulation?.[platform] === true
+
sections.push({
platform,
enabled: true,
collapsed: false,
group_ids: groupIds,
model_mapping: { ...mapping },
- model_pricing: pricing
+ model_pricing: pricing,
+ web_search_emulation: webSearchEnabled,
})
}
@@ -1008,10 +1060,10 @@ async function loadChannels() {
if (ctrl.signal.aborted || abortController !== ctrl) return
channels.value = response.items || []
pagination.total = response.total
- } catch (error: any) {
- if (error?.name === 'AbortError' || error?.code === 'ERR_CANCELED') return
- appStore.showError(t('admin.channels.loadError', 'Failed to load channels'))
- console.error('Error loading channels:', error)
+ } catch (error: unknown) {
+ const e = error as { name?: string; code?: string }
+ if (e?.name === 'AbortError' || e?.code === 'ERR_CANCELED') return
+ appStore.showError(extractApiErrorMessage(error, t('admin.channels.loadError', 'Failed to load channels')))
} finally {
if (abortController === ctrl) {
loading.value = false
@@ -1210,7 +1262,7 @@ async function handleSubmit() {
}
}
- const { group_ids, model_pricing, model_mapping } = formToAPI()
+ const { group_ids, model_pricing, model_mapping, features_config } = formToAPI()
submitting.value = true
try {
@@ -1224,6 +1276,7 @@ async function handleSubmit() {
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
billing_model_source: form.billing_model_source,
restrict_models: form.restrict_models,
+ features_config,
apply_pricing_to_account_stats: form.apply_pricing_to_account_stats,
account_stats_pricing_rules: accountStatsRulesToAPI()
}
@@ -1238,6 +1291,7 @@ async function handleSubmit() {
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
billing_model_source: form.billing_model_source,
restrict_models: form.restrict_models,
+ features_config,
apply_pricing_to_account_stats: form.apply_pricing_to_account_stats,
account_stats_pricing_rules: accountStatsRulesToAPI()
}
@@ -1246,12 +1300,10 @@ async function handleSubmit() {
}
closeDialog()
loadChannels()
- } catch (error: any) {
- const msg = error.response?.data?.detail || (editingChannel.value
+ } catch (error: unknown) {
+ appStore.showError(extractApiErrorMessage(error, editingChannel.value
? t('admin.channels.updateError', 'Failed to update channel')
- : t('admin.channels.createError', 'Failed to create channel'))
- appStore.showError(msg)
- console.error('Error saving channel:', error)
+ : t('admin.channels.createError', 'Failed to create channel')))
} finally {
submitting.value = false
}
@@ -1289,9 +1341,8 @@ async function confirmDelete() {
showDeleteDialog.value = false
deletingChannel.value = null
loadChannels()
- } catch (error: any) {
- appStore.showError(error.response?.data?.detail || t('admin.channels.deleteError', 'Failed to delete channel'))
- console.error('Error deleting channel:', error)
+ } catch (error: unknown) {
+ appStore.showError(extractApiErrorMessage(error, t('admin.channels.deleteError', 'Failed to delete channel')))
}
}
@@ -1299,6 +1350,7 @@ async function confirmDelete() {
onMounted(() => {
loadChannels()
loadGroups()
+ loadWebSearchGlobalState()
})
onUnmounted(() => {
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index faab49fe..b5181ccc 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -1789,40 +1789,42 @@
-
+
{{ t('admin.settings.webSearchEmulation.apiKey') }}
-
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -1858,44 +1860,19 @@
{{ provider.quota_used ?? 0 }} / {{ provider.quota_limit }}
-
-
-
{{ t('admin.settings.webSearchEmulation.proxy') }}
-
-
-
-
-
-
-
-
- {{ wsTestLoading ? t('admin.settings.webSearchEmulation.testing') : t('admin.settings.webSearchEmulation.test') }}
-
-
-
-
-
- {{ t('admin.settings.webSearchEmulation.testResultProvider') }}: {{ wsTestResult.provider }}
-
-
- {{ t('admin.settings.webSearchEmulation.testNoResults') }}
-
-
-
{{ r.title }}
-
{{ r.snippet && r.snippet.length > 120 ? r.snippet.slice(0, 120) + '...' : r.snippet }}
-
+
+
+
+
{{ t('admin.settings.webSearchEmulation.proxy') }}
+
+
+ {{ t('admin.settings.webSearchEmulation.test') }}
+
@@ -1903,6 +1880,50 @@
+
+
+
+
+ {{ t('admin.settings.webSearchEmulation.testResultTitle') }}
+
+
+
+
+ {{ wsTestLoading ? t('admin.settings.webSearchEmulation.testing') : t('admin.settings.webSearchEmulation.test') }}
+
+
+
+
+
+ {{ t('admin.settings.webSearchEmulation.testResultProvider') }}: {{ wsTestResult.provider }}
+
+
+ {{ t('admin.settings.webSearchEmulation.testNoResults') }}
+
+
+
+
+
+ {{ t('common.close') }}
+
+
+
+
+
@@ -3016,6 +3037,12 @@ const apiKeyVisible = reactive>({})
const wsTestQuery = ref('')
const wsTestLoading = ref(false)
const wsTestResult = ref(null)
+const wsTestDialogOpen = ref(false)
+
+function openTestDialog() {
+ wsTestResult.value = null
+ wsTestDialogOpen.value = true
+}
function toggleProviderExpand(idx: number) {
expandedProviders[idx] = !expandedProviders[idx]
From eba289a7ff82430a1570d84a97c4bdf2d10e9775 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 17:49:58 +0800
Subject: [PATCH 28/88] feat(notify): add global toggles, percentage threshold,
and visibility control
- Add global toggle for account quota notification in admin settings
- Add percentage-based threshold type for per-account quota alerts
- Hide balance notify card on user profile when global toggle is off
- Expose balance_low_notify_enabled and account_quota_notify_enabled in PublicSettings
- Add threshold type (fixed/percentage) to QuotaNotifyToggle with $ / % switcher
---
backend/internal/service/account.go | 20 ++++++++
.../service/balance_notify_service.go | 50 ++++++++++++++-----
backend/internal/service/domain_constants.go | 3 +-
backend/internal/service/setting_service.go | 12 ++++-
backend/internal/service/settings_view.go | 6 ++-
frontend/src/api/admin/settings.ts | 2 +
.../components/account/EditAccountModal.vue | 24 +++++++++
.../src/components/account/QuotaLimitCard.vue | 15 ++++++
.../components/account/QuotaNotifyToggle.vue | 31 ++++++++++--
frontend/src/i18n/locales/en.ts | 3 +-
frontend/src/i18n/locales/zh.ts | 3 +-
frontend/src/stores/app.ts | 4 +-
frontend/src/types/index.ts | 2 +
frontend/src/views/admin/SettingsView.vue | 10 +++-
frontend/src/views/user/ProfileView.vue | 5 +-
15 files changed, 164 insertions(+), 26 deletions(-)
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 0b225dac..6e5a768f 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -1432,6 +1432,14 @@ func (a *Account) getExtraString(key string) string {
return ""
}
+// getExtraStringDefault 从 Extra 中读取指定 key 的字符串值,不存在时返回 defaultVal
+func (a *Account) getExtraStringDefault(key, defaultVal string) string {
+ if v := a.getExtraString(key); v != "" {
+ return v
+ }
+ return defaultVal
+}
+
// getExtraInt 从 Extra 中读取指定 key 的 int 值
func (a *Account) getExtraInt(key string) int {
if a.Extra == nil {
@@ -1498,6 +1506,10 @@ func (a *Account) GetQuotaNotifyDailyThreshold() float64 {
return a.getExtraFloat64("quota_notify_daily_threshold")
}
+func (a *Account) GetQuotaNotifyDailyThresholdType() string {
+ return a.getExtraStringDefault("quota_notify_daily_threshold_type", "fixed")
+}
+
func (a *Account) GetQuotaNotifyWeeklyEnabled() bool {
return a.getExtraBool("quota_notify_weekly_enabled")
}
@@ -1506,6 +1518,10 @@ func (a *Account) GetQuotaNotifyWeeklyThreshold() float64 {
return a.getExtraFloat64("quota_notify_weekly_threshold")
}
+func (a *Account) GetQuotaNotifyWeeklyThresholdType() string {
+ return a.getExtraStringDefault("quota_notify_weekly_threshold_type", "fixed")
+}
+
func (a *Account) GetQuotaNotifyTotalEnabled() bool {
return a.getExtraBool("quota_notify_total_enabled")
}
@@ -1514,6 +1530,10 @@ func (a *Account) GetQuotaNotifyTotalThreshold() float64 {
return a.getExtraFloat64("quota_notify_total_threshold")
}
+func (a *Account) GetQuotaNotifyTotalThresholdType() string {
+ return a.getExtraStringDefault("quota_notify_total_threshold_type", "fixed")
+}
+
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
t := after.In(tz)
diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go
index 0d7e4c09..65cec594 100644
--- a/backend/internal/service/balance_notify_service.go
+++ b/backend/internal/service/balance_notify_service.go
@@ -82,19 +82,29 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u
// quotaDim describes one quota dimension for notification checking.
type quotaDim struct {
- name string
- enabled bool
- threshold float64
- oldUsed float64
- limit float64
+ name string
+ enabled bool
+ threshold float64
+ thresholdType string // "fixed" (default) or "percentage"
+ oldUsed float64
+ limit float64
+}
+
+// resolvedThreshold returns the effective threshold value.
+// For percentage type, it computes threshold = limit * percentage / 100.
+func (d quotaDim) resolvedThreshold() float64 {
+ if d.thresholdType == "percentage" && d.limit > 0 {
+ return d.limit * d.threshold / 100
+ }
+ return d.threshold
}
// buildQuotaDims returns the three quota dimensions for notification checking.
func buildQuotaDims(account *Account) []quotaDim {
return []quotaDim{
- {quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaDailyUsed(), account.GetQuotaDailyLimit()},
- {quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaWeeklyUsed(), account.GetQuotaWeeklyLimit()},
- {quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaUsed(), account.GetQuotaLimit()},
+ {quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaNotifyDailyThresholdType(), account.GetQuotaDailyUsed(), account.GetQuotaDailyLimit()},
+ {quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaNotifyWeeklyThresholdType(), account.GetQuotaWeeklyUsed(), account.GetQuotaWeeklyLimit()},
+ {quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaNotifyTotalThresholdType(), account.GetQuotaUsed(), account.GetQuotaLimit()},
}
}
@@ -104,6 +114,9 @@ func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Conte
if account == nil || s.emailService == nil || s.settingRepo == nil || cost <= 0 {
return
}
+ if !s.isAccountQuotaNotifyEnabled(ctx) {
+ return
+ }
adminEmails := s.getAccountQuotaNotifyEmails(ctx)
if len(adminEmails) == 0 {
return
@@ -114,22 +127,26 @@ func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Conte
if !dim.enabled || dim.threshold <= 0 {
continue
}
+ effectiveThreshold := dim.resolvedThreshold()
+ if effectiveThreshold <= 0 {
+ continue
+ }
newUsed := dim.oldUsed + cost
- if dim.oldUsed < dim.threshold && newUsed >= dim.threshold {
- s.asyncSendQuotaAlert(adminEmails, account.Name, dim, newUsed, siteName)
+ if dim.oldUsed < effectiveThreshold && newUsed >= effectiveThreshold {
+ s.asyncSendQuotaAlert(adminEmails, account.Name, dim, newUsed, effectiveThreshold, siteName)
}
}
}
// asyncSendQuotaAlert sends quota alert email in a goroutine with panic recovery.
-func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, accountName string, dim quotaDim, newUsed float64, siteName string) {
+func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, accountName string, dim quotaDim, newUsed, effectiveThreshold float64, siteName string) {
go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in quota notification", "recover", r)
}
}()
- s.sendQuotaAlertEmails(adminEmails, accountName, dim.name, newUsed, dim.limit, dim.threshold, siteName)
+ s.sendQuotaAlertEmails(adminEmails, accountName, dim.name, newUsed, dim.limit, effectiveThreshold, siteName)
}()
}
@@ -149,6 +166,15 @@ func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enab
return
}
+// isAccountQuotaNotifyEnabled checks the global account quota notification toggle.
+func (s *BalanceNotifyService) isAccountQuotaNotifyEnabled(ctx context.Context) bool {
+ val, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEnabled)
+ if err != nil {
+ return false
+ }
+ return val == "true"
+}
+
// getAccountQuotaNotifyEmails reads admin notification emails from settings.
func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) []string {
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEmails)
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index 2704e0d0..f07ddfd4 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -255,7 +255,8 @@ const (
SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD)
// Account Quota Notification
- SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组)
+ SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关
+ SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组)
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index f0cf750a..abcae9c1 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -182,6 +182,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingPaymentEnabled,
SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName,
+ SettingKeyBalanceLowNotifyEnabled,
+ SettingKeyAccountQuotaNotifyEnabled,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -249,6 +251,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
OIDCOAuthProviderName: oidcProviderName,
+ BalanceLowNotifyEnabled: settings[SettingKeyBalanceLowNotifyEnabled] == "true",
+ AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true",
}, nil
}
@@ -302,6 +306,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
Version string `json:"version,omitempty"`
+ BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
+ AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
}{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
@@ -332,6 +338,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
Version: s.version,
+ BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
+ AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
}, nil
}
@@ -609,6 +617,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// Balance low notification
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64)
+ updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
accountQuotaNotifyEmailsJSON, err := json.Marshal(settings.AccountQuotaNotifyEmails)
if err != nil {
return fmt.Errorf("marshal account quota notify emails: %w", err)
@@ -1251,7 +1260,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.BalanceLowNotifyThreshold = v
}
- // Account quota notification emails
+ // Account quota notification
+ result.AccountQuotaNotifyEnabled = settings[SettingKeyAccountQuotaNotifyEnabled] == "true"
if raw := strings.TrimSpace(settings[SettingKeyAccountQuotaNotifyEmails]); raw != "" {
var emails []string
if err := json.Unmarshal([]byte(raw), &emails); err == nil {
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index debc2b19..b79b930a 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -112,7 +112,8 @@ type SystemSettings struct {
BalanceLowNotifyThreshold float64
// Account quota notification
- AccountQuotaNotifyEmails []string
+ AccountQuotaNotifyEnabled bool
+ AccountQuotaNotifyEmails []string
}
type DefaultSubscriptionSetting struct {
@@ -152,6 +153,9 @@ type PublicSettings struct {
OIDCOAuthEnabled bool
OIDCOAuthProviderName string
Version string
+
+ BalanceLowNotifyEnabled bool
+ AccountQuotaNotifyEnabled bool
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 31284289..5c5de2d1 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -138,6 +138,7 @@ export interface SystemSettings {
// Balance & quota notification
balance_low_notify_enabled: boolean
balance_low_notify_threshold: number
+ account_quota_notify_enabled: boolean
account_quota_notify_emails: string[]
}
@@ -241,6 +242,7 @@ export interface UpdateSettingsRequest {
// Balance & quota notification
balance_low_notify_enabled?: boolean
balance_low_notify_threshold?: number
+ account_quota_notify_enabled?: boolean
account_quota_notify_emails?: string[]
}
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index 086575e6..abb9569e 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -1188,10 +1188,13 @@
:resetTimezone="editResetTimezone"
:quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled"
:quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold"
+ :quotaNotifyDailyThresholdType="editQuotaNotifyDailyThresholdType"
:quotaNotifyWeeklyEnabled="editQuotaNotifyWeeklyEnabled"
:quotaNotifyWeeklyThreshold="editQuotaNotifyWeeklyThreshold"
+ :quotaNotifyWeeklyThresholdType="editQuotaNotifyWeeklyThresholdType"
:quotaNotifyTotalEnabled="editQuotaNotifyTotalEnabled"
:quotaNotifyTotalThreshold="editQuotaNotifyTotalThreshold"
+ :quotaNotifyTotalThresholdType="editQuotaNotifyTotalThresholdType"
@update:totalLimit="editQuotaLimit = $event"
@update:dailyLimit="editQuotaDailyLimit = $event"
@update:weeklyLimit="editQuotaWeeklyLimit = $event"
@@ -1203,10 +1206,13 @@
@update:resetTimezone="editResetTimezone = $event"
@update:quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled = $event"
@update:quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold = $event"
+ @update:quotaNotifyDailyThresholdType="editQuotaNotifyDailyThresholdType = $event"
@update:quotaNotifyWeeklyEnabled="editQuotaNotifyWeeklyEnabled = $event"
@update:quotaNotifyWeeklyThreshold="editQuotaNotifyWeeklyThreshold = $event"
+ @update:quotaNotifyWeeklyThresholdType="editQuotaNotifyWeeklyThresholdType = $event"
@update:quotaNotifyTotalEnabled="editQuotaNotifyTotalEnabled = $event"
@update:quotaNotifyTotalThreshold="editQuotaNotifyTotalThreshold = $event"
+ @update:quotaNotifyTotalThresholdType="editQuotaNotifyTotalThresholdType = $event"
/>
@@ -1232,10 +1238,13 @@
:resetTimezone="editResetTimezone"
:quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled"
:quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold"
+ :quotaNotifyDailyThresholdType="editQuotaNotifyDailyThresholdType"
:quotaNotifyWeeklyEnabled="editQuotaNotifyWeeklyEnabled"
:quotaNotifyWeeklyThreshold="editQuotaNotifyWeeklyThreshold"
+ :quotaNotifyWeeklyThresholdType="editQuotaNotifyWeeklyThresholdType"
:quotaNotifyTotalEnabled="editQuotaNotifyTotalEnabled"
:quotaNotifyTotalThreshold="editQuotaNotifyTotalThreshold"
+ :quotaNotifyTotalThresholdType="editQuotaNotifyTotalThresholdType"
@update:totalLimit="editQuotaLimit = $event"
@update:dailyLimit="editQuotaDailyLimit = $event"
@update:weeklyLimit="editQuotaWeeklyLimit = $event"
@@ -1247,10 +1256,13 @@
@update:resetTimezone="editResetTimezone = $event"
@update:quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled = $event"
@update:quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold = $event"
+ @update:quotaNotifyDailyThresholdType="editQuotaNotifyDailyThresholdType = $event"
@update:quotaNotifyWeeklyEnabled="editQuotaNotifyWeeklyEnabled = $event"
@update:quotaNotifyWeeklyThreshold="editQuotaNotifyWeeklyThreshold = $event"
+ @update:quotaNotifyWeeklyThresholdType="editQuotaNotifyWeeklyThresholdType = $event"
@update:quotaNotifyTotalEnabled="editQuotaNotifyTotalEnabled = $event"
@update:quotaNotifyTotalThreshold="editQuotaNotifyTotalThreshold = $event"
+ @update:quotaNotifyTotalThresholdType="editQuotaNotifyTotalThresholdType = $event"
/>
@@ -1992,10 +2004,13 @@ const editWeeklyResetHour = ref(null)
const editResetTimezone = ref(null)
const editQuotaNotifyDailyEnabled = ref(null)
const editQuotaNotifyDailyThreshold = ref(null)
+const editQuotaNotifyDailyThresholdType = ref(null)
const editQuotaNotifyWeeklyEnabled = ref(null)
const editQuotaNotifyWeeklyThreshold = ref(null)
+const editQuotaNotifyWeeklyThresholdType = ref(null)
const editQuotaNotifyTotalEnabled = ref(null)
const editQuotaNotifyTotalThreshold = ref(null)
+const editQuotaNotifyTotalThresholdType = ref(null)
const openAIWSModeOptions = computed(() => [
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
// TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复
@@ -2198,10 +2213,13 @@ const syncFormFromAccount = (newAccount: Account | null) => {
// Load quota notify config
editQuotaNotifyDailyEnabled.value = (extra?.quota_notify_daily_enabled as boolean) ?? null
editQuotaNotifyDailyThreshold.value = (extra?.quota_notify_daily_threshold as number) ?? null
+ editQuotaNotifyDailyThresholdType.value = (extra?.quota_notify_daily_threshold_type as string) ?? null
editQuotaNotifyWeeklyEnabled.value = (extra?.quota_notify_weekly_enabled as boolean) ?? null
editQuotaNotifyWeeklyThreshold.value = (extra?.quota_notify_weekly_threshold as number) ?? null
+ editQuotaNotifyWeeklyThresholdType.value = (extra?.quota_notify_weekly_threshold_type as string) ?? null
editQuotaNotifyTotalEnabled.value = (extra?.quota_notify_total_enabled as boolean) ?? null
editQuotaNotifyTotalThreshold.value = (extra?.quota_notify_total_threshold as number) ?? null
+ editQuotaNotifyTotalThresholdType.value = (extra?.quota_notify_total_threshold_type as string) ?? null
} else {
editQuotaLimit.value = null
editQuotaDailyLimit.value = null
@@ -3262,9 +3280,11 @@ const handleSubmit = async () => {
} else {
delete newExtra.quota_notify_daily_threshold
}
+ newExtra.quota_notify_daily_threshold_type = editQuotaNotifyDailyThresholdType.value || 'fixed'
} else {
delete newExtra.quota_notify_daily_enabled
delete newExtra.quota_notify_daily_threshold
+ delete newExtra.quota_notify_daily_threshold_type
}
if (editQuotaNotifyWeeklyEnabled.value) {
newExtra.quota_notify_weekly_enabled = true
@@ -3273,9 +3293,11 @@ const handleSubmit = async () => {
} else {
delete newExtra.quota_notify_weekly_threshold
}
+ newExtra.quota_notify_weekly_threshold_type = editQuotaNotifyWeeklyThresholdType.value || 'fixed'
} else {
delete newExtra.quota_notify_weekly_enabled
delete newExtra.quota_notify_weekly_threshold
+ delete newExtra.quota_notify_weekly_threshold_type
}
if (editQuotaNotifyTotalEnabled.value) {
newExtra.quota_notify_total_enabled = true
@@ -3284,9 +3306,11 @@ const handleSubmit = async () => {
} else {
delete newExtra.quota_notify_total_threshold
}
+ newExtra.quota_notify_total_threshold_type = editQuotaNotifyTotalThresholdType.value || 'fixed'
} else {
delete newExtra.quota_notify_total_enabled
delete newExtra.quota_notify_total_threshold
+ delete newExtra.quota_notify_total_threshold_type
}
updatePayload.extra = newExtra
}
diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue
index 7c3afd23..64bdb08a 100644
--- a/frontend/src/components/account/QuotaLimitCard.vue
+++ b/frontend/src/components/account/QuotaLimitCard.vue
@@ -17,17 +17,23 @@ const props = withDefaults(defineProps<{
resetTimezone: string | null
quotaNotifyDailyEnabled?: boolean | null
quotaNotifyDailyThreshold?: number | null
+ quotaNotifyDailyThresholdType?: string | null
quotaNotifyWeeklyEnabled?: boolean | null
quotaNotifyWeeklyThreshold?: number | null
+ quotaNotifyWeeklyThresholdType?: string | null
quotaNotifyTotalEnabled?: boolean | null
quotaNotifyTotalThreshold?: number | null
+ quotaNotifyTotalThresholdType?: string | null
}>(), {
quotaNotifyDailyEnabled: null,
quotaNotifyDailyThreshold: null,
+ quotaNotifyDailyThresholdType: null,
quotaNotifyWeeklyEnabled: null,
quotaNotifyWeeklyThreshold: null,
+ quotaNotifyWeeklyThresholdType: null,
quotaNotifyTotalEnabled: null,
quotaNotifyTotalThreshold: null,
+ quotaNotifyTotalThresholdType: null,
})
const emit = defineEmits<{
@@ -42,10 +48,13 @@ const emit = defineEmits<{
'update:resetTimezone': [value: string | null]
'update:quotaNotifyDailyEnabled': [value: boolean | null]
'update:quotaNotifyDailyThreshold': [value: number | null]
+ 'update:quotaNotifyDailyThresholdType': [value: string | null]
'update:quotaNotifyWeeklyEnabled': [value: boolean | null]
'update:quotaNotifyWeeklyThreshold': [value: number | null]
+ 'update:quotaNotifyWeeklyThresholdType': [value: string | null]
'update:quotaNotifyTotalEnabled': [value: boolean | null]
'update:quotaNotifyTotalThreshold': [value: number | null]
+ 'update:quotaNotifyTotalThresholdType': [value: string | null]
}>()
const enabled = computed(() =>
@@ -228,8 +237,10 @@ const onWeeklyModeChange = (e: Event) => {
v-if="dailyLimit && dailyLimit > 0"
:enabled="props.quotaNotifyDailyEnabled"
:threshold="props.quotaNotifyDailyThreshold"
+ :threshold-type="props.quotaNotifyDailyThresholdType"
@update:enabled="emit('update:quotaNotifyDailyEnabled', $event)"
@update:threshold="emit('update:quotaNotifyDailyThreshold', $event)"
+ @update:threshold-type="emit('update:quotaNotifyDailyThresholdType', $event)"
/>
@@ -292,8 +303,10 @@ const onWeeklyModeChange = (e: Event) => {
v-if="weeklyLimit && weeklyLimit > 0"
:enabled="props.quotaNotifyWeeklyEnabled"
:threshold="props.quotaNotifyWeeklyThreshold"
+ :threshold-type="props.quotaNotifyWeeklyThresholdType"
@update:enabled="emit('update:quotaNotifyWeeklyEnabled', $event)"
@update:threshold="emit('update:quotaNotifyWeeklyThreshold', $event)"
+ @update:threshold-type="emit('update:quotaNotifyWeeklyThresholdType', $event)"
/>
@@ -330,8 +343,10 @@ const onWeeklyModeChange = (e: Event) => {
v-if="totalLimit && totalLimit > 0"
:enabled="props.quotaNotifyTotalEnabled"
:threshold="props.quotaNotifyTotalThreshold"
+ :threshold-type="props.quotaNotifyTotalThresholdType"
@update:enabled="emit('update:quotaNotifyTotalEnabled', $event)"
@update:threshold="emit('update:quotaNotifyTotalThreshold', $event)"
+ @update:threshold-type="emit('update:quotaNotifyTotalThresholdType', $event)"
/>
diff --git a/frontend/src/components/account/QuotaNotifyToggle.vue b/frontend/src/components/account/QuotaNotifyToggle.vue
index 4634f5b1..b1c22fe2 100644
--- a/frontend/src/components/account/QuotaNotifyToggle.vue
+++ b/frontend/src/components/account/QuotaNotifyToggle.vue
@@ -6,12 +6,18 @@ const { t } = useI18n()
defineProps<{
enabled: boolean | null
threshold: number | null
+ thresholdType: string | null // "fixed" (default) or "percentage"
}>()
const emit = defineEmits<{
'update:enabled': [value: boolean | null]
'update:threshold': [value: number | null]
+ 'update:thresholdType': [value: string | null]
}>()
+
+function toggleType(current: string | null) {
+ emit('update:thresholdType', current === 'percentage' ? 'fixed' : 'percentage')
+}
@@ -32,15 +38,32 @@ const emit = defineEmits<{
]"
/>
-
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 8e10bf2a..6688c1b6 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -2257,7 +2257,7 @@ export default {
alert: 'Alert Threshold',
enabled: 'Enable Alert',
threshold: 'Alert Amount',
- thresholdPlaceholder: 'Enter alert amount',
+ thresholdPlaceholder: 'Enter percentage',
},
testConnection: 'Test Connection',
reAuthorize: 'Re-Authorize',
@@ -4640,6 +4640,7 @@ export default {
quotaNotify: {
title: 'Account Quota Notification',
description: 'Notify admins when account quota usage reaches alert threshold',
+ enabled: 'Enable Account Quota Notification',
emails: 'Notification Emails',
emailsHint: 'Leave empty to disable notifications',
addEmail: 'Add Email',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 1b82f419..70d20cc0 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -2255,7 +2255,7 @@ export default {
alert: '告警阈值',
enabled: '启用告警',
threshold: '告警金额',
- thresholdPlaceholder: '输入告警金额',
+ thresholdPlaceholder: '输入百分比',
},
testConnection: '测试连接',
reAuthorize: '重新授权',
@@ -4804,6 +4804,7 @@ export default {
quotaNotify: {
title: '账号限额通知',
description: '当账号配额用量达到告警阈值时通知管理员',
+ enabled: '启用账号限额通知',
emails: '通知邮箱',
emailsHint: '留空则不发送通知',
addEmail: '添加邮箱',
diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts
index 09e73621..b69c3648 100644
--- a/frontend/src/stores/app.ts
+++ b/frontend/src/stores/app.ts
@@ -339,7 +339,9 @@ export const useAppStore = defineStore('app', () => {
oidc_oauth_enabled: false,
oidc_oauth_provider_name: 'OIDC',
backend_mode_enabled: false,
- version: siteVersion.value
+ version: siteVersion.value,
+ balance_low_notify_enabled: false,
+ account_quota_notify_enabled: false,
}
}
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index e74f6e61..c6c74354 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -117,6 +117,8 @@ export interface PublicSettings {
oidc_oauth_provider_name: string
backend_mode_enabled: boolean
version: string
+ balance_low_notify_enabled: boolean
+ account_quota_notify_enabled: boolean
}
export interface AuthResponse {
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index b5181ccc..222d1dd7 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -2718,11 +2718,15 @@
-
+
+ {{ t('admin.settings.quotaNotify.enabled') }}
+
+
+
{{ t('admin.settings.quotaNotify.emails') }}
-
+
@@ -3018,6 +3022,7 @@ const form = reactive({
// Balance & quota notification
balance_low_notify_enabled: false,
balance_low_notify_threshold: 0,
+ account_quota_notify_enabled: false,
account_quota_notify_emails: [] as string[]
})
@@ -3588,6 +3593,7 @@ async function saveSettings() {
// Balance & quota notification
balance_low_notify_enabled: form.balance_low_notify_enabled,
balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0,
+ account_quota_notify_enabled: form.account_quota_notify_enabled,
account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e: string) => e.trim() !== ''),
}
diff --git a/frontend/src/views/user/ProfileView.vue b/frontend/src/views/user/ProfileView.vue
index 5534e1d6..f801e20d 100644
--- a/frontend/src/views/user/ProfileView.vue
+++ b/frontend/src/views/user/ProfileView.vue
@@ -15,7 +15,7 @@
authStore.user)
const contactInfo = ref('')
+const balanceLowNotifyEnabled = ref(false)
const WalletIcon = { render: () => h('svg', { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, [h('path', { d: 'M21 12a2.25 2.25 0 00-2.25-2.25H15a3 3 0 11-6 0H5.25A2.25 2.25 0 003 12' })]) }
const BoltIcon = { render: () => h('svg', { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, [h('path', { d: 'm3.75 13.5 10.5-11.25L12 10.5h8.25L9.75 21.75 12 13.5H3.75z' })]) }
const CalendarIcon = { render: () => h('svg', { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, [h('path', { d: 'M6.75 3v2.25M17.25 3v2.25' })]) }
-onMounted(async () => { try { const s = await authAPI.getPublicSettings(); contactInfo.value = s.contact_info || '' } catch (error) { console.error('Failed to load contact info:', error) } })
+onMounted(async () => { try { const s = await authAPI.getPublicSettings(); contactInfo.value = s.contact_info || ''; balanceLowNotifyEnabled.value = s.balance_low_notify_enabled ?? false } catch (error) { console.error('Failed to load contact info:', error) } })
const formatCurrency = (v: number) => `$${v.toFixed(2)}`
\ No newline at end of file
From 4e96a6faece261ca60905a01c751f087fd299976 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 18:11:47 +0800
Subject: [PATCH 29/88] fix: address audit findings for notify, websearch and
security
- Fix GetByKeyForAuth missing user.FieldEmail and user.FieldUsername (notifications sent to empty address)
- Guard against empty email in collectBalanceNotifyRecipients
- Remove non-atomic TotalRecharged read-modify-write in admin balance adjustment
- HTML-escape userName/siteName/accountName in notification email templates
- Fix timer leak in ProfileBalanceNotifyCard (add onUnmounted cleanup)
- Add warning log on websearch proxy URL resolution failure
---
backend/internal/repository/api_key_repo.go | 2 ++
backend/internal/service/admin_service.go | 6 ------
backend/internal/service/balance_notify_service.go | 10 +++++++---
backend/internal/service/websearch_config.go | 1 +
.../user/profile/ProfileBalanceNotifyCard.vue | 6 +++++-
5 files changed, 15 insertions(+), 10 deletions(-)
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index 11eac7a8..d1b42750 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -139,6 +139,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
WithUser(func(q *dbent.UserQuery) {
q.Select(
user.FieldID,
+ user.FieldEmail,
+ user.FieldUsername,
user.FieldStatus,
user.FieldRole,
user.FieldBalance,
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index a4e22b22..97b42c24 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -709,12 +709,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance)
}
- // Track cumulative recharge for percentage-based balance notifications
- balanceDelta := user.Balance - oldBalance
- if balanceDelta > 0 {
- user.TotalRecharged += balanceDelta
- }
-
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go
index 65cec594..68202958 100644
--- a/backend/internal/service/balance_notify_service.go
+++ b/backend/internal/service/balance_notify_service.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "html"
"log/slog"
"strconv"
"strings"
@@ -195,7 +196,10 @@ func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
// collectBalanceNotifyRecipients collects all email recipients for balance notifications.
func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string {
- recipients := []string{user.Email}
+ var recipients []string
+ if user.Email != "" {
+ recipients = append(recipients, user.Email)
+ }
for _, extra := range user.BalanceNotifyExtraEmails {
email := strings.TrimSpace(extra)
if email != "" && !strings.EqualFold(email, user.Email) {
@@ -224,7 +228,7 @@ func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userNam
displayName = userEmail
}
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", siteName)
- body := s.buildBalanceLowEmailBody(displayName, balance, threshold, siteName)
+ body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName))
s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
}
@@ -236,7 +240,7 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun
}
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", siteName, accountName)
- body := s.buildQuotaAlertEmailBody(accountName, dimLabel, used, limit, threshold, siteName)
+ body := s.buildQuotaAlertEmailBody(html.EscapeString(accountName), html.EscapeString(dimLabel), used, limit, threshold, html.EscapeString(siteName))
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension)
}
diff --git a/backend/internal/service/websearch_config.go b/backend/internal/service/websearch_config.go
index 346faf1f..99e40275 100644
--- a/backend/internal/service/websearch_config.go
+++ b/backend/internal/service/websearch_config.go
@@ -235,6 +235,7 @@ func (s *SettingService) resolveProviderProxyURLs(ctx context.Context, cfg *WebS
}
proxies, err := s.proxyRepo.ListByIDs(ctx, ids)
if err != nil {
+ slog.Warn("websearch: failed to resolve proxy URLs", "error", err)
return nil
}
result := make(map[int64]string, len(proxies))
diff --git a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
index 130d82b5..758704e0 100644
--- a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
+++ b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
@@ -93,7 +93,7 @@
\ No newline at end of file
From 422807514c9f147c12f09b86decb9cbbc9312627 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 20:40:31 +0800
Subject: [PATCH 35/88] fix(notify): add duplicate email check message and
improve extra email UX
---
.../user/profile/ProfileBalanceNotifyCard.vue | 12 +++++++-----
frontend/src/i18n/locales/en.ts | 1 +
frontend/src/i18n/locales/zh.ts | 1 +
3 files changed, 9 insertions(+), 5 deletions(-)
diff --git a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
index cfe1b332..797616bb 100644
--- a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
+++ b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
@@ -49,14 +49,16 @@
-
+
{{ email }}
-
- {{ t('profile.balanceNotify.removeEmail') }}
-
+
+
+ {{ t('profile.balanceNotify.removeEmail') }}
+
+
@@ -185,7 +187,7 @@ function addPendingEmail() {
const email = newEmail.value.trim()
if (!email) return
if (email === props.userEmail || extraEmails.value.includes(email) || pendingEmails.value.some(p => p.email === email)) {
- appStore.showError(t('common.error'))
+ appStore.showError(t('profile.balanceNotify.emailDuplicate'))
return
}
pendingEmails.value.push({ email, codeSent: false, code: '', sending: false, verifying: false, countdown: 0, timer: null })
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 1b4cb9ea..ff77837e 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -929,6 +929,7 @@ export default {
verifySuccess: 'Email added successfully',
removeEmail: 'Remove',
removeSuccess: 'Email removed',
+ emailDuplicate: 'This email already exists',
}
},
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index ca091dd5..9eabb465 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -933,6 +933,7 @@ export default {
verifySuccess: '邮箱添加成功',
removeEmail: '移除',
removeSuccess: '邮箱已移除',
+ emailDuplicate: '该邮箱已存在',
}
},
From 61aa197b0bb3905e5e5b582eeb25c25995670f96 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 20:45:58 +0800
Subject: [PATCH 36/88] fix(notify): add explicit save button for balance
threshold
Replace blur-based auto-save with an explicit Save button so users
know when their threshold is persisted. Shows success toast on save.
---
.../user/profile/ProfileBalanceNotifyCard.vue | 15 +++++++++++++--
1 file changed, 13 insertions(+), 2 deletions(-)
diff --git a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
index 797616bb..589cc9e8 100644
--- a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
+++ b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
@@ -19,7 +19,7 @@
-
+
{{ t('profile.balanceNotify.threshold') }}
@@ -34,8 +34,14 @@
step="0.01"
class="input flex-1"
:placeholder="systemDefaultThreshold > 0 ? `${t('profile.balanceNotify.systemDefault')} $${systemDefaultThreshold}` : t('profile.balanceNotify.thresholdPlaceholder')"
- @blur="handleThresholdUpdate"
/>
+
+ {{ savingThreshold ? t('common.saving') : t('common.save') }}
+
@@ -152,6 +158,7 @@ const customThreshold = ref(props.threshold)
const extraEmails = ref([...props.extraEmails])
const pendingEmails = ref([])
const newEmail = ref('')
+const savingThreshold = ref(false)
watch(() => props.enabled, (val) => { notifyEnabled.value = val })
watch(() => props.threshold, (val) => { customThreshold.value = val })
@@ -174,12 +181,16 @@ const handleToggle = async () => {
}
const handleThresholdUpdate = async () => {
+ savingThreshold.value = true
try {
const threshold = customThreshold.value && customThreshold.value > 0 ? customThreshold.value : 0
const updated = await userAPI.updateProfile({ balance_notify_threshold: threshold })
authStore.user = updated
+ appStore.showSuccess(t('common.saved'))
} catch (err: unknown) {
appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ } finally {
+ savingThreshold.value = false
}
}
From 915b7a4a56aa859aa9de041d01c09ccbd988c78f Mon Sep 17 00:00:00 2001
From: erio
Date: Mon, 13 Apr 2026 00:52:42 +0800
Subject: [PATCH 37/88] feat(notify): convert email lists to NotifyEmailEntry
struct with toggle support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Change balance_notify_extra_emails and account_quota_notify_emails
from []string to []NotifyEmailEntry{email, disabled, verified}
- Add per-email enable/disable toggle for both user and admin notifications
- Add PUT /user/notify-email/toggle API endpoint
- Fix critical bug: API key auth cache snapshot missing balance notify
fields (Email, Username, BalanceNotifyEnabled, etc.), causing
notifications to never fire on cached request paths
- Bump cache snapshot version 3→4 to invalidate stale entries
- Add SQL migration 104 to convert old format data
- Backward compatible: parseNotifyEmails auto-detects old/new format
- User balance notify: max 3 emails (primary + 2 extra)
- Admin quota notify: unlimited emails, each with toggle
---
.../internal/handler/admin/setting_handler.go | 10 +-
backend/internal/handler/dto/mappers.go | 2 +-
.../handler/dto/notify_email_entry.go | 43 +++++++
backend/internal/handler/dto/settings.go | 2 +-
backend/internal/handler/dto/types.go | 2 +-
backend/internal/handler/user_handler.go | 36 ++++++
backend/internal/repository/api_key_repo.go | 8 +-
backend/internal/repository/user_repo.go | 14 +--
backend/internal/server/routes/user.go | 1 +
.../internal/service/api_key_auth_cache.go | 8 ++
.../service/api_key_auth_cache_impl.go | 34 ++++--
.../service/balance_notify_service.go | 63 +++++++++--
.../internal/service/notify_email_entry.go | 107 ++++++++++++++++++
backend/internal/service/setting_service.go | 7 +-
backend/internal/service/settings_view.go | 2 +-
backend/internal/service/user.go | 2 +-
backend/internal/service/user_service.go | 42 +++++--
.../104_migrate_notify_emails_to_struct.sql | 35 ++++++
frontend/src/api/admin/settings.ts | 6 +-
frontend/src/api/user.ts | 17 ++-
.../user/profile/ProfileBalanceNotifyCard.vue | 72 ++++++++----
frontend/src/i18n/locales/en.ts | 2 +
frontend/src/i18n/locales/zh.ts | 2 +
frontend/src/types/index.ts | 12 +-
frontend/src/views/admin/SettingsView.vue | 14 ++-
25 files changed, 448 insertions(+), 95 deletions(-)
create mode 100644 backend/internal/handler/dto/notify_email_entry.go
create mode 100644 backend/internal/service/notify_email_entry.go
create mode 100644 backend/migrations/104_migrate_notify_emails_to_struct.sql
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 49e7aeed..fe46e821 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -178,7 +178,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
- AccountQuotaNotifyEmails: settings.AccountQuotaNotifyEmails,
+ AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
PaymentEnabled: paymentCfg.Enabled,
PaymentMinAmount: paymentCfg.MinAmount,
PaymentMaxAmount: paymentCfg.MaxAmount,
@@ -311,7 +311,7 @@ type UpdateSettingsRequest struct {
// Balance low notification
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
- AccountQuotaNotifyEmails *[]string `json:"account_quota_notify_emails"`
+ AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"`
// Payment configuration (integrated into settings, full replace)
PaymentEnabled *bool `json:"payment_enabled"`
@@ -902,9 +902,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.BalanceLowNotifyThreshold
}(),
- AccountQuotaNotifyEmails: func() []string {
+ AccountQuotaNotifyEmails: func() []service.NotifyEmailEntry {
if req.AccountQuotaNotifyEmails != nil {
- return *req.AccountQuotaNotifyEmails
+ return dto.NotifyEmailEntriesToService(*req.AccountQuotaNotifyEmails)
}
return previousSettings.AccountQuotaNotifyEmails
}(),
@@ -1056,7 +1056,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableCCHSigning: updatedSettings.EnableCCHSigning,
BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
- AccountQuotaNotifyEmails: updatedSettings.AccountQuotaNotifyEmails,
+ AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
PaymentEnabled: updatedPaymentCfg.Enabled,
PaymentMinAmount: updatedPaymentCfg.MinAmount,
PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index 147072c3..a7a93d07 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -26,7 +26,7 @@ func UserFromServiceShallow(u *service.User) *User {
BalanceNotifyEnabled: u.BalanceNotifyEnabled,
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
- BalanceNotifyExtraEmails: u.BalanceNotifyExtraEmails,
+ BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
TotalRecharged: u.TotalRecharged,
}
}
diff --git a/backend/internal/handler/dto/notify_email_entry.go b/backend/internal/handler/dto/notify_email_entry.go
new file mode 100644
index 00000000..180f8b25
--- /dev/null
+++ b/backend/internal/handler/dto/notify_email_entry.go
@@ -0,0 +1,43 @@
+package dto
+
+import "github.com/Wei-Shaw/sub2api/internal/service"
+
+// NotifyEmailEntry represents a notification email with enable/disable and verification state.
+// Email="" is a placeholder for the "primary email" (user's registration email or first admin email).
+type NotifyEmailEntry struct {
+ Email string `json:"email"`
+ Disabled bool `json:"disabled"`
+ Verified bool `json:"verified"`
+}
+
+// NotifyEmailEntriesFromService converts service entries to DTO entries.
+func NotifyEmailEntriesFromService(entries []service.NotifyEmailEntry) []NotifyEmailEntry {
+ if entries == nil {
+ return nil
+ }
+ result := make([]NotifyEmailEntry, len(entries))
+ for i, e := range entries {
+ result[i] = NotifyEmailEntry{
+ Email: e.Email,
+ Disabled: e.Disabled,
+ Verified: e.Verified,
+ }
+ }
+ return result
+}
+
+// NotifyEmailEntriesToService converts DTO entries to service entries.
+func NotifyEmailEntriesToService(entries []NotifyEmailEntry) []service.NotifyEmailEntry {
+ if entries == nil {
+ return nil
+ }
+ result := make([]service.NotifyEmailEntry, len(entries))
+ for i, e := range entries {
+ result[i] = service.NotifyEmailEntry{
+ Email: e.Email,
+ Disabled: e.Disabled,
+ Verified: e.Verified,
+ }
+ }
+ return result
+}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 9c2ff263..545458c8 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -152,7 +152,7 @@ type SystemSettings struct {
// Balance low notification
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
- AccountQuotaNotifyEmails []string `json:"account_quota_notify_emails"`
+ AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
}
type DefaultSubscriptionSetting struct {
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 425d3df9..afb782b0 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -22,7 +22,7 @@ type User struct {
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
- BalanceNotifyExtraEmails []string `json:"balance_notify_extra_emails"`
+ BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
TotalRecharged float64 `json:"total_recharged"`
APIKeys []APIKey `json:"api_keys,omitempty"`
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 4fb72ce7..9e0a243a 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -214,3 +214,39 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
response.Success(c, dto.UserFromService(updatedUser))
}
+
+// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
+type ToggleNotifyEmailRequest struct {
+ Email string `json:"email"` // empty string for primary email placeholder
+ Disabled bool `json:"disabled"`
+}
+
+// ToggleNotifyEmail toggles the disabled state of a notification email
+// PUT /api/v1/user/notify-email/toggle
+func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req ToggleNotifyEmailRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ err := h.userService.ToggleNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Disabled)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, dto.UserFromService(updatedUser))
+}
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index d1b42750..38ea9bde 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -3,7 +3,6 @@ package repository
import (
"context"
"database/sql"
- "encoding/json"
"fmt"
"strings"
"time"
@@ -667,12 +666,9 @@ func userEntityToService(u *dbent.User) *service.User {
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
- // Parse extra emails JSON array
+ // Parse extra emails JSON (supports both old []string and new []NotifyEmailEntry format)
if u.BalanceNotifyExtraEmails != "" && u.BalanceNotifyExtraEmails != "[]" {
- var emails []string
- if err := json.Unmarshal([]byte(u.BalanceNotifyExtraEmails), &emails); err == nil {
- out.BalanceNotifyExtraEmails = emails
- }
+ out.BalanceNotifyExtraEmails = service.ParseNotifyEmails(u.BalanceNotifyExtraEmails)
}
return out
}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 1792ef8d..913e1c40 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -3,7 +3,6 @@ package repository
import (
"context"
"database/sql"
- "encoding/json"
"errors"
"fmt"
"sort"
@@ -563,16 +562,9 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst.UpdatedAt = src.UpdatedAt
}
-// marshalExtraEmails serializes a string slice to JSON for storage.
-func marshalExtraEmails(emails []string) string {
- if len(emails) == 0 {
- return "[]"
- }
- data, err := json.Marshal(emails)
- if err != nil {
- return "[]"
- }
- return string(data)
+// marshalExtraEmails serializes notify email entries to JSON for storage.
+func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
+ return service.MarshalNotifyEmails(entries)
}
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go
index 088565fa..d004f8b4 100644
--- a/backend/internal/server/routes/user.go
+++ b/backend/internal/server/routes/user.go
@@ -31,6 +31,7 @@ func RegisterUserRoutes(
{
notifyEmail.POST("/send-code", h.User.SendNotifyEmailCode)
notifyEmail.POST("/verify", h.User.VerifyNotifyEmail)
+ notifyEmail.PUT("/toggle", h.User.ToggleNotifyEmail)
notifyEmail.DELETE("", h.User.RemoveNotifyEmail)
}
diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go
index c2e96df1..60cb6233 100644
--- a/backend/internal/service/api_key_auth_cache.go
+++ b/backend/internal/service/api_key_auth_cache.go
@@ -34,6 +34,14 @@ type APIKeyAuthUserSnapshot struct {
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
+
+ // Balance notification fields (required for CheckBalanceAfterDeduction)
+ Email string `json:"email"`
+ Username string `json:"username"`
+ BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
+ BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
+ BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
}
// APIKeyAuthGroupSnapshot 分组快照
diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go
index 8069ed4f..711090c2 100644
--- a/backend/internal/service/api_key_auth_cache_impl.go
+++ b/backend/internal/service/api_key_auth_cache_impl.go
@@ -13,7 +13,7 @@ import (
"github.com/dgraph-io/ristretto"
)
-const apiKeyAuthSnapshotVersion = 3
+const apiKeyAuthSnapshotVersion = 4 // v4: added balance notification fields to UserSnapshot
type apiKeyAuthCacheConfig struct {
l1Size int
@@ -219,11 +219,17 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
RateLimit1d: apiKey.RateLimit1d,
RateLimit7d: apiKey.RateLimit7d,
User: APIKeyAuthUserSnapshot{
- ID: apiKey.User.ID,
- Status: apiKey.User.Status,
- Role: apiKey.User.Role,
- Balance: apiKey.User.Balance,
- Concurrency: apiKey.User.Concurrency,
+ ID: apiKey.User.ID,
+ Status: apiKey.User.Status,
+ Role: apiKey.User.Role,
+ Balance: apiKey.User.Balance,
+ Concurrency: apiKey.User.Concurrency,
+ Email: apiKey.User.Email,
+ Username: apiKey.User.Username,
+ BalanceNotifyEnabled: apiKey.User.BalanceNotifyEnabled,
+ BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType,
+ BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
+ BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
},
}
if apiKey.Group != nil {
@@ -274,11 +280,17 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
RateLimit1d: snapshot.RateLimit1d,
RateLimit7d: snapshot.RateLimit7d,
User: &User{
- ID: snapshot.User.ID,
- Status: snapshot.User.Status,
- Role: snapshot.User.Role,
- Balance: snapshot.User.Balance,
- Concurrency: snapshot.User.Concurrency,
+ ID: snapshot.User.ID,
+ Status: snapshot.User.Status,
+ Role: snapshot.User.Role,
+ Balance: snapshot.User.Balance,
+ Concurrency: snapshot.User.Concurrency,
+ Email: snapshot.User.Email,
+ Username: snapshot.User.Username,
+ BalanceNotifyEnabled: snapshot.User.BalanceNotifyEnabled,
+ BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType,
+ BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
+ BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
},
}
if snapshot.Group != nil {
diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go
index 68202958..ba1c7037 100644
--- a/backend/internal/service/balance_notify_service.go
+++ b/backend/internal/service/balance_notify_service.go
@@ -176,13 +176,38 @@ func (s *BalanceNotifyService) isAccountQuotaNotifyEnabled(ctx context.Context)
return val == "true"
}
-// getAccountQuotaNotifyEmails reads admin notification emails from settings.
+// getAccountQuotaNotifyEmails reads admin notification emails from settings,
+// filtering out disabled entries. Entries with email="" are resolved to the first admin's email.
func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) []string {
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEmails)
if err != nil || strings.TrimSpace(raw) == "" || raw == "[]" {
return nil
}
- return parseJSONStringArray(raw)
+
+ entries := ParseNotifyEmails(raw)
+ if len(entries) == 0 {
+ return nil
+ }
+
+ var recipients []string
+ seen := make(map[string]bool)
+ for _, entry := range entries {
+ if entry.Disabled {
+ continue
+ }
+ email := strings.TrimSpace(entry.Email)
+ // email="" placeholder is not resolved here; admin should configure actual emails
+ if email == "" {
+ continue
+ }
+ lower := strings.ToLower(email)
+ if seen[lower] {
+ continue
+ }
+ seen[lower] = true
+ recipients = append(recipients, email)
+ }
+ return recipients
}
// getSiteName reads site name from settings with fallback.
@@ -194,18 +219,36 @@ func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
return name
}
-// collectBalanceNotifyRecipients collects all email recipients for balance notifications.
+// collectBalanceNotifyRecipients collects all non-disabled email recipients for balance notifications.
+// Entries with email="" are resolved to the user's primary email.
func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string {
var recipients []string
- if user.Email != "" {
+ seen := make(map[string]bool)
+
+ for _, entry := range user.BalanceNotifyExtraEmails {
+ if entry.Disabled {
+ continue
+ }
+ email := strings.TrimSpace(entry.Email)
+ if email == "" {
+ email = user.Email // Resolve primary email placeholder
+ }
+ if email == "" {
+ continue
+ }
+ lower := strings.ToLower(email)
+ if seen[lower] {
+ continue
+ }
+ seen[lower] = true
+ recipients = append(recipients, email)
+ }
+
+ // If no entries exist at all (legacy/empty), fall back to user's primary email
+ if len(user.BalanceNotifyExtraEmails) == 0 && user.Email != "" {
recipients = append(recipients, user.Email)
}
- for _, extra := range user.BalanceNotifyExtraEmails {
- email := strings.TrimSpace(extra)
- if email != "" && !strings.EqualFold(email, user.Email) {
- recipients = append(recipients, email)
- }
- }
+
return recipients
}
diff --git a/backend/internal/service/notify_email_entry.go b/backend/internal/service/notify_email_entry.go
new file mode 100644
index 00000000..3caf689f
--- /dev/null
+++ b/backend/internal/service/notify_email_entry.go
@@ -0,0 +1,107 @@
+package service
+
+import (
+ "encoding/json"
+ "strings"
+)
+
+// NotifyEmailEntry represents a notification email with enable/disable and verification state.
+// Email="" is a placeholder for the "primary email" (user's registration email or first admin email).
+type NotifyEmailEntry struct {
+ Email string `json:"email"`
+ Disabled bool `json:"disabled"`
+ Verified bool `json:"verified"`
+}
+
+// ParseNotifyEmails parses a JSON string into []NotifyEmailEntry.
+// It auto-detects the format:
+// - Old format ["email1","email2"] → converted to [{email, disabled:false, verified:true}, ...]
+// - New format [{email,disabled,verified}, ...] → parsed directly
+//
+// Returns nil on empty/invalid input.
+func ParseNotifyEmails(raw string) []NotifyEmailEntry {
+ raw = strings.TrimSpace(raw)
+ if raw == "" || raw == "[]" {
+ return nil
+ }
+
+ // Try parsing as new format first (array of objects)
+ var entries []NotifyEmailEntry
+ if err := json.Unmarshal([]byte(raw), &entries); err == nil && len(entries) > 0 {
+ // Verify it's actually the new format by checking the first element
+ // json.Unmarshal into []NotifyEmailEntry succeeds even for ["string"]
+ // because it tries to fit "string" into NotifyEmailEntry and gets zero values.
+ // We need to detect old format explicitly.
+ if !isOldStringArrayFormat(raw) {
+ return entries
+ }
+ }
+
+ // Try parsing as old format (array of strings)
+ var emails []string
+ if err := json.Unmarshal([]byte(raw), &emails); err == nil {
+ result := make([]NotifyEmailEntry, 0, len(emails))
+ for _, e := range emails {
+ e = strings.TrimSpace(e)
+ if e != "" {
+ result = append(result, NotifyEmailEntry{
+ Email: e,
+ Disabled: false,
+ Verified: false, // Old format emails default to unverified
+ })
+ }
+ }
+ return result
+ }
+
+ return nil
+}
+
+// isOldStringArrayFormat checks if the JSON is a string array like ["email1","email2"].
+func isOldStringArrayFormat(raw string) bool {
+ var arr []json.RawMessage
+ if err := json.Unmarshal([]byte(raw), &arr); err != nil || len(arr) == 0 {
+ return false
+ }
+ // Check if first element starts with a quote (string) vs { (object)
+ first := strings.TrimSpace(string(arr[0]))
+ return len(first) > 0 && first[0] == '"'
+}
+
+// MarshalNotifyEmails serializes []NotifyEmailEntry to a JSON string.
+func MarshalNotifyEmails(entries []NotifyEmailEntry) string {
+ if len(entries) == 0 {
+ return "[]"
+ }
+ data, err := json.Marshal(entries)
+ if err != nil {
+ return "[]"
+ }
+ return string(data)
+}
+
+// FilterEnabledEmails returns only non-disabled email addresses from entries.
+// Empty email placeholders are skipped (caller should resolve them separately).
+func FilterEnabledEmails(entries []NotifyEmailEntry) []string {
+ var result []string
+ for _, e := range entries {
+ if e.Disabled {
+ continue
+ }
+ email := strings.TrimSpace(e.Email)
+ if email != "" {
+ result = append(result, email)
+ }
+ }
+ return result
+}
+
+// IsPrimaryDisabled reports whether the primary email placeholder (email="") exists and is disabled.
+func IsPrimaryDisabled(entries []NotifyEmailEntry) bool {
+ for _, e := range entries {
+ if e.Email == "" {
+ return e.Disabled
+ }
+ }
+ return false // No primary placeholder = not disabled
+}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 773b84ba..0267040d 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -1272,13 +1272,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
// Account quota notification
result.AccountQuotaNotifyEnabled = settings[SettingKeyAccountQuotaNotifyEnabled] == "true"
if raw := strings.TrimSpace(settings[SettingKeyAccountQuotaNotifyEmails]); raw != "" {
- var emails []string
- if err := json.Unmarshal([]byte(raw), &emails); err == nil {
- result.AccountQuotaNotifyEmails = emails
- }
+ result.AccountQuotaNotifyEmails = ParseNotifyEmails(raw)
}
if result.AccountQuotaNotifyEmails == nil {
- result.AccountQuotaNotifyEmails = []string{}
+ result.AccountQuotaNotifyEmails = []NotifyEmailEntry{}
}
return result
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index 274ec792..1346372d 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -113,7 +113,7 @@ type SystemSettings struct {
// Account quota notification
AccountQuotaNotifyEnabled bool
- AccountQuotaNotifyEmails []string
+ AccountQuotaNotifyEmails []NotifyEmailEntry
}
type DefaultSubscriptionSetting struct {
diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go
index 4ca31adc..d3d8c954 100644
--- a/backend/internal/service/user.go
+++ b/backend/internal/service/user.go
@@ -34,7 +34,7 @@ type User struct {
BalanceNotifyEnabled bool
BalanceNotifyThresholdType string // "fixed" (default) | "percentage"
BalanceNotifyThreshold *float64
- BalanceNotifyExtraEmails []string
+ BalanceNotifyExtraEmails []NotifyEmailEntry
TotalRecharged float64
APIKeys []APIKey
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index e6b9a210..6b75140f 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -18,7 +18,7 @@ var (
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
)
-const maxNotifyExtraEmails = 5
+const maxNotifyEmails = 3 // Total limit: primary (email="") + up to 2 extra
// UserListFilters contains all filter options for listing users
type UserListFilters struct {
@@ -338,17 +338,21 @@ func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64,
// Check if already exists
for _, e := range user.BalanceNotifyExtraEmails {
- if strings.EqualFold(e, email) {
+ if strings.EqualFold(e.Email, email) {
return nil // Already added
}
}
- // Check limit
- if len(user.BalanceNotifyExtraEmails) >= maxNotifyExtraEmails {
- return infraerrors.BadRequest("TOO_MANY_NOTIFY_EMAILS", fmt.Sprintf("maximum %d extra notification emails allowed", maxNotifyExtraEmails))
+ // Check limit (total includes primary email="" placeholder + extra emails)
+ if len(user.BalanceNotifyExtraEmails) >= maxNotifyEmails {
+ return infraerrors.BadRequest("TOO_MANY_NOTIFY_EMAILS", fmt.Sprintf("maximum %d notification emails allowed", maxNotifyEmails))
}
- user.BalanceNotifyExtraEmails = append(user.BalanceNotifyExtraEmails, email)
+ user.BalanceNotifyExtraEmails = append(user.BalanceNotifyExtraEmails, NotifyEmailEntry{
+ Email: email,
+ Disabled: false,
+ Verified: true,
+ })
return s.userRepo.Update(ctx, user)
}
@@ -359,9 +363,9 @@ func (s *UserService) RemoveNotifyEmail(ctx context.Context, userID int64, email
return err
}
- filtered := make([]string, 0, len(user.BalanceNotifyExtraEmails))
+ filtered := make([]NotifyEmailEntry, 0, len(user.BalanceNotifyExtraEmails))
for _, e := range user.BalanceNotifyExtraEmails {
- if !strings.EqualFold(e, email) {
+ if !strings.EqualFold(e.Email, email) {
filtered = append(filtered, e)
}
}
@@ -369,6 +373,28 @@ func (s *UserService) RemoveNotifyEmail(ctx context.Context, userID int64, email
return s.userRepo.Update(ctx, user)
}
+// ToggleNotifyEmail toggles the disabled state of a notification email entry.
+func (s *UserService) ToggleNotifyEmail(ctx context.Context, userID int64, email string, disabled bool) error {
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return err
+ }
+
+ found := false
+ for i, e := range user.BalanceNotifyExtraEmails {
+ if strings.EqualFold(e.Email, email) {
+ user.BalanceNotifyExtraEmails[i].Disabled = disabled
+ found = true
+ break
+ }
+ }
+ if !found {
+ return infraerrors.BadRequest("EMAIL_NOT_FOUND", "notification email not found")
+ }
+
+ return s.userRepo.Update(ctx, user)
+}
+
// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification.
func buildNotifyVerifyEmailBody(code, siteName string) string {
return fmt.Sprintf(`
diff --git a/backend/migrations/104_migrate_notify_emails_to_struct.sql b/backend/migrations/104_migrate_notify_emails_to_struct.sql
new file mode 100644
index 00000000..4356da4f
--- /dev/null
+++ b/backend/migrations/104_migrate_notify_emails_to_struct.sql
@@ -0,0 +1,35 @@
+-- Migrate notification email lists from old []string format to new []NotifyEmailEntry format
+-- Old: ["a@x.com", "b@x.com"]
+-- New: [{"email":"a@x.com","disabled":false,"verified":false}, ...]
+-- Existing emails are marked as verified=false (unverified), disabled=false (enabled)
+
+-- 1. User balance notification emails
+UPDATE users
+SET balance_notify_extra_emails = (
+ SELECT COALESCE(
+ jsonb_agg(jsonb_build_object('email', elem::text, 'disabled', false, 'verified', false)),
+ '[]'::jsonb
+ )::text
+ FROM jsonb_array_elements_text(balance_notify_extra_emails::jsonb) AS elem
+)
+WHERE balance_notify_extra_emails IS NOT NULL
+ AND balance_notify_extra_emails <> '[]'
+ AND balance_notify_extra_emails <> ''
+ AND (balance_notify_extra_emails::jsonb -> 0) IS NOT NULL
+ AND jsonb_typeof(balance_notify_extra_emails::jsonb -> 0) = 'string';
+
+-- 2. Admin account quota notification emails
+UPDATE settings
+SET value = (
+ SELECT COALESCE(
+ jsonb_agg(jsonb_build_object('email', elem::text, 'disabled', false, 'verified', false)),
+ '[]'::jsonb
+ )::text
+ FROM jsonb_array_elements_text(value::jsonb) AS elem
+)
+WHERE key = 'account_quota_notify_emails'
+ AND value IS NOT NULL
+ AND value <> '[]'
+ AND value <> ''
+ AND (value::jsonb -> 0) IS NOT NULL
+ AND jsonb_typeof(value::jsonb -> 0) = 'string';
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 5c5de2d1..230232df 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -4,7 +4,7 @@
*/
import { apiClient } from '../client'
-import type { CustomMenuItem, CustomEndpoint } from '@/types'
+import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from '@/types'
export interface DefaultSubscriptionSetting {
group_id: number
@@ -139,7 +139,7 @@ export interface SystemSettings {
balance_low_notify_enabled: boolean
balance_low_notify_threshold: number
account_quota_notify_enabled: boolean
- account_quota_notify_emails: string[]
+ account_quota_notify_emails: NotifyEmailEntry[]
}
export interface UpdateSettingsRequest {
@@ -243,7 +243,7 @@ export interface UpdateSettingsRequest {
balance_low_notify_enabled?: boolean
balance_low_notify_threshold?: number
account_quota_notify_enabled?: boolean
- account_quota_notify_emails?: string[]
+ account_quota_notify_emails?: NotifyEmailEntry[]
}
/**
diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts
index 9ef0f59c..cd648270 100644
--- a/frontend/src/api/user.ts
+++ b/frontend/src/api/user.ts
@@ -4,7 +4,7 @@
*/
import { apiClient } from './client'
-import type { User, ChangePasswordRequest } from '@/types'
+import type { User, ChangePasswordRequest, NotifyEmailEntry } from '@/types'
/**
* Get current user profile
@@ -24,7 +24,7 @@ export async function updateProfile(profile: {
username?: string
balance_notify_enabled?: boolean
balance_notify_threshold?: number | null
- balance_notify_extra_emails?: string[]
+ balance_notify_extra_emails?: NotifyEmailEntry[]
}): Promise {
const { data } = await apiClient.put('/user', profile)
return data
@@ -73,13 +73,24 @@ export async function removeNotifyEmail(email: string): Promise {
await apiClient.delete('/user/notify-email', { data: { email } })
}
+/**
+ * Toggle a notify email's disabled state
+ * @param email - Email address (empty string for primary email placeholder)
+ * @param disabled - Whether to disable the email
+ */
+export async function toggleNotifyEmail(email: string, disabled: boolean): Promise {
+ const { data } = await apiClient.put('/user/notify-email/toggle', { email, disabled })
+ return data
+}
+
export const userAPI = {
getProfile,
updateProfile,
changePassword,
sendNotifyEmailCode,
verifyNotifyEmail,
- removeNotifyEmail
+ removeNotifyEmail,
+ toggleNotifyEmail
}
export default userAPI
diff --git a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
index 589cc9e8..1d88ad82 100644
--- a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
+++ b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue
@@ -45,23 +45,26 @@
-
+
{{ t('profile.balanceNotify.extraEmails') }}
-
- {{ userEmail }}
- {{ t('profile.balanceNotify.primaryEmail') }}
-
-
-
-
-
-
+
-
{{ email }}
-
-
+
+
+
+
+
+
+ {{ entry.email === '' ? userEmail : entry.email }}
+
+
+
+ {{ t('profile.balanceNotify.primaryEmail') }}
+ {{ t('profile.balanceNotify.unverified') }}
+
{{ t('profile.balanceNotify.removeEmail') }}
@@ -100,8 +103,8 @@
-
-
+
+
+
+ {{ t('profile.balanceNotify.maxEmailsReached') }}
+
@@ -124,12 +130,15 @@
From 11c4606874d1ade3219ce75bf405463118998172 Mon Sep 17 00:00:00 2001
From: erio
Date: Mon, 13 Apr 2026 02:28:31 +0800
Subject: [PATCH 40/88] fix(channel): use upstream model for account stats
pricing and remove channel pricing fallback
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- resolveAccountStatsCost now uses the final upstream model (after
account-level mapping) to match custom pricing rules, fixing the
issue where requested model (e.g. claude-sonnet-4-5) didn't match
rules configured for upstream model (e.g. claude-opus-4-6)
- Remove tryChannelPricing fallback — only custom rules are applied,
unmatched requests use default formula (total_cost × rate)
- Remove unused billingService and serviceTier parameters
- Update description: "启用后将支持自定义账号统计的模型价格"
---
.../internal/service/account_stats_pricing.go | 38 ++++---------------
backend/internal/service/gateway_service.go | 13 ++++---
.../service/openai_gateway_service.go | 12 ++++--
frontend/src/i18n/locales/en.ts | 2 +-
frontend/src/i18n/locales/zh.ts | 2 +-
5 files changed, 25 insertions(+), 42 deletions(-)
diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go
index 86f98a12..4a896d9f 100644
--- a/backend/internal/service/account_stats_pricing.go
+++ b/backend/internal/service/account_stats_pricing.go
@@ -8,23 +8,18 @@ import (
// resolveAccountStatsCost 计算账号统计定价费用。
// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
-//
-// 匹配优先级(先命中为准):
-// 1. 自定义规则(AccountStatsPricingRules,按数组顺序遍历)
-// 2. 渠道已有的模型定价(ApplyPricingToAccountStats 开启时)
-// 3. nil → 走默认公式
+// 仅匹配自定义规则(AccountStatsPricingRules),按数组顺序先命中为准。
+// upstreamModel 是最终发往上游的模型 ID,用于匹配自定义规则中的模型定价。
func resolveAccountStatsCost(
ctx context.Context,
channelService *ChannelService,
- billingService *BillingService,
accountID int64,
groupID int64,
- billingModel string,
+ upstreamModel string,
tokens UsageTokens,
requestCount int,
- serviceTier string,
) *float64 {
- if channelService == nil || billingService == nil {
+ if channelService == nil || upstreamModel == "" {
return nil
}
channel, err := channelService.GetChannelForGroup(ctx, groupID)
@@ -33,22 +28,15 @@ func resolveAccountStatsCost(
}
platform := channelService.GetGroupPlatform(ctx, groupID)
- modelLower := strings.ToLower(billingModel)
-
- // 优先级 1:自定义规则
- if cost := tryCustomRules(channel, accountID, groupID, platform, modelLower, tokens, requestCount); cost != nil {
- return cost
- }
-
- // 优先级 2:渠道已有模型定价
- return tryChannelPricing(ctx, channelService, groupID, billingModel, tokens, requestCount)
+ return tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount)
}
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
func tryCustomRules(
channel *Channel, accountID, groupID int64,
- platform, modelLower string, tokens UsageTokens, requestCount int,
+ platform, model string, tokens UsageTokens, requestCount int,
) *float64 {
+ modelLower := strings.ToLower(model)
for _, rule := range channel.AccountStatsPricingRules {
if !matchAccountStatsRule(&rule, accountID, groupID) {
continue
@@ -62,18 +50,6 @@ func tryCustomRules(
return nil
}
-// tryChannelPricing 使用渠道已有的模型定价计算账号统计费用。
-func tryChannelPricing(
- ctx context.Context, channelService *ChannelService,
- groupID int64, billingModel string, tokens UsageTokens, requestCount int,
-) *float64 {
- pricing := channelService.GetChannelModelPricing(ctx, groupID, billingModel)
- if pricing == nil {
- return nil
- }
- return calculateStatsCost(pricing, tokens, requestCount)
-}
-
// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index d68a7771..70dd9b52 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -7581,11 +7581,15 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
- // 计算账号统计定价费用
+ // 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if apiKey.GroupID != nil {
+ upstreamModel := result.UpstreamModel
+ if upstreamModel == "" {
+ upstreamModel = result.Model
+ }
usageLog.AccountStatsCost = resolveAccountStatsCost(
- ctx, s.channelService, s.billingService,
- account.ID, *apiKey.GroupID, billingModel,
+ ctx, s.channelService,
+ account.ID, *apiKey.GroupID, upstreamModel,
UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
@@ -7593,8 +7597,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
CacheReadTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
},
- 1, // requestCount
- "", // serviceTier: Anthropic 平台不使用 service tier
+ 1, // requestCount
)
}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 70abd4ce..98258cd0 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -4573,12 +4573,16 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID
}
- // 计算账号统计定价费用
+ // 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if apiKey.GroupID != nil {
+ statsModel := result.UpstreamModel
+ if statsModel == "" {
+ statsModel = result.Model
+ }
usageLog.AccountStatsCost = resolveAccountStatsCost(
- ctx, s.channelService, s.billingService,
- account.ID, *apiKey.GroupID, billingModel,
- tokens, 1, serviceTier,
+ ctx, s.channelService,
+ account.ID, *apiKey.GroupID, statsModel,
+ tokens, 1,
)
}
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index 07ae0c3d..f45a02f6 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -1877,7 +1877,7 @@ export default {
pricingEntry: 'Pricing Entry',
noModels: 'No models added',
applyPricingToAccountStats: 'Apply Pricing to Account Stats',
- applyPricingToAccountStatsDesc: 'When enabled, account statistics cost will use channel model pricing. Account rate multiplier still applies.',
+ applyPricingToAccountStatsDesc: 'When enabled, custom account stats model pricing rules will be applied.',
accountStatsPricingRules: 'Custom Account Stats Pricing Rules',
addRule: 'Add Rule',
noRulesConfigured: 'No custom rules configured. Channel model pricing above will be used.',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 64d42c12..61d43a37 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -1956,7 +1956,7 @@ export default {
pricingEntry: '定价配置',
noModels: '未添加模型',
applyPricingToAccountStats: '应用模型定价到账号统计',
- applyPricingToAccountStatsDesc: '启用后,账号统计费用将使用渠道模型定价计算。账号自身的统计倍率仍然生效。',
+ applyPricingToAccountStatsDesc: '启用后将支持自定义账号统计的模型价格',
accountStatsPricingRules: '自定义账号统计定价规则',
addRule: '添加规则',
noRulesConfigured: '未配置自定义规则,将使用上方的模型定价。',
From 1262654d9785673cfda3071a714f3a61734933e7 Mon Sep 17 00:00:00 2001
From: erio
Date: Mon, 13 Apr 2026 11:37:08 +0800
Subject: [PATCH 41/88] feat: WebSearch tri-state, account stats pricing fix,
quota cache fix, usage tooltip
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
WebSearch tri-state switch:
- Account-level web_search_emulation changed from bool to tri-state
string: "default" (follow channel) / "enabled" / "disabled"
- shouldEmulateWebSearch checks channel config when account is "default"
- SQL migration converts old bool values
- Frontend select replaces toggle in Edit/CreateAccountModal
Account stats pricing:
- resolveAccountStatsCost uses upstream model (post-mapping) for matching
- Priority: custom rules → model pricing file (when toggle on) → default
- Custom rules always configurable, independent of toggle
- Account ID field changed to searchable selector filtered by platform
- Description updated to reflect new behavior
Quota notification cache fix:
- CheckAccountQuotaAfterIncrement fetches real-time account from DB
- Reconstructs pre-increment usage for accurate threshold crossing detection
- New AccountQuotaReader interface (minimal: GetByID only)
Usage tooltip:
- Per-request/image billing shows per-request price instead of $0 token price
- Token billing continues to show input/output price per million tokens
---
backend/cmd/server/wire_gen.go | 2 +-
backend/internal/handler/gateway_handler.go | 3 +
backend/internal/service/account.go | 29 +++-
.../internal/service/account_stats_pricing.go | 41 ++++-
.../service/balance_notify_service.go | 43 +++++-
backend/internal/service/gateway_request.go | 3 +
backend/internal/service/gateway_service.go | 4 +-
.../service/gateway_websearch_emulation.go | 24 ++-
.../service/openai_gateway_service.go | 2 +-
backend/internal/service/wire.go | 4 +-
...igrate_websearch_emulation_to_tristate.sql | 11 ++
.../components/account/CreateAccountModal.vue | 21 +--
.../components/account/EditAccountModal.vue | 27 +++-
.../src/components/admin/usage/UsageTable.vue | 22 ++-
frontend/src/i18n/locales/en.ts | 12 +-
frontend/src/i18n/locales/zh.ts | 12 +-
frontend/src/views/admin/ChannelsView.vue | 143 ++++++++++++++++--
frontend/src/views/user/UsageView.vue | 22 ++-
18 files changed, 346 insertions(+), 79 deletions(-)
create mode 100644 backend/migrations/105_migrate_websearch_emulation_to_tristate.sql
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 24ee02fd..69daeecf 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -176,7 +176,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
- balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository)
+ balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index 59619d50..8ec54420 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -248,6 +248,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
+ // 设置请求所属分组 ID(用于渠道级功能判断,如 WebSearch 模拟)
+ parsedReq.GroupID = apiKey.GroupID
+
// 计算粘性会话hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 6e5a768f..4d933986 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -1169,15 +1169,30 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
return ok && enabled
}
-// IsWebSearchEmulationEnabled 返回 Anthropic API Key 账号是否启用 web search 模拟。
-// 字段:accounts.extra.web_search_emulation。
-// 字段缺失或类型不正确时,按 false(关闭)处理。
-func (a *Account) IsWebSearchEmulationEnabled() bool {
+// WebSearch 模拟三态常量
+const (
+ WebSearchModeDefault = "default" // 跟随渠道配置
+ WebSearchModeEnabled = "enabled" // 强制开启
+ WebSearchModeDisabled = "disabled" // 强制关闭
+)
+
+// GetWebSearchEmulationMode 返回账号的 WebSearch 模拟模式。
+// 三态:default(跟随渠道)/ enabled(强制开启)/ disabled(强制关闭)。
+// 旧 bool 值需通过 SQL 迁移脚本转换,Go 代码不做兼容。
+func (a *Account) GetWebSearchEmulationMode() string {
if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
- return false
+ return WebSearchModeDefault
+ }
+ mode, ok := a.Extra[featureKeyWebSearchEmulation].(string)
+ if !ok {
+ return WebSearchModeDefault
+ }
+ switch mode {
+ case WebSearchModeEnabled, WebSearchModeDisabled:
+ return mode
+ default:
+ return WebSearchModeDefault
}
- enabled, ok := a.Extra[featureKeyWebSearchEmulation].(bool)
- return ok && enabled
}
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go
index 4a896d9f..e88f7f8c 100644
--- a/backend/internal/service/account_stats_pricing.go
+++ b/backend/internal/service/account_stats_pricing.go
@@ -8,11 +8,17 @@ import (
// resolveAccountStatsCost 计算账号统计定价费用。
// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
-// 仅匹配自定义规则(AccountStatsPricingRules),按数组顺序先命中为准。
-// upstreamModel 是最终发往上游的模型 ID,用于匹配自定义规则中的模型定价。
+//
+// 优先级(先命中为准):
+// 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关)
+// 2. ApplyPricingToAccountStats 启用时,用模型定价文件(LiteLLM)中上游模型的标准价格计算
+// 3. nil → 走默认公式
+//
+// upstreamModel 是最终发往上游的模型 ID。
func resolveAccountStatsCost(
ctx context.Context,
channelService *ChannelService,
+ billingService *BillingService,
accountID int64,
groupID int64,
upstreamModel string,
@@ -23,12 +29,39 @@ func resolveAccountStatsCost(
return nil
}
channel, err := channelService.GetChannelForGroup(ctx, groupID)
- if err != nil || channel == nil || !channel.ApplyPricingToAccountStats {
+ if err != nil || channel == nil {
return nil
}
platform := channelService.GetGroupPlatform(ctx, groupID)
- return tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount)
+
+ // 优先级 1:自定义规则(始终尝试)
+ if cost := tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount); cost != nil {
+ return cost
+ }
+
+ // 优先级 2:模型定价文件(LiteLLM/fallback)中上游模型的标准价格
+ if channel.ApplyPricingToAccountStats && billingService != nil {
+ return tryModelFilePricing(billingService, upstreamModel, tokens)
+ }
+
+ return nil
+}
+
+// tryModelFilePricing 使用模型定价文件(LiteLLM/fallback)中的标准价格计算费用。
+func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 {
+ pricing, err := billingService.GetModelPricing(model)
+ if err != nil || pricing == nil {
+ return nil
+ }
+ cost := float64(tokens.InputTokens)*pricing.InputPricePerToken +
+ float64(tokens.OutputTokens)*pricing.OutputPricePerToken +
+ float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken +
+ float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken
+ if cost <= 0 {
+ return nil
+ }
+ return &cost
}
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go
index 053451e1..23411ed5 100644
--- a/backend/internal/service/balance_notify_service.go
+++ b/backend/internal/service/balance_notify_service.go
@@ -27,17 +27,24 @@ var quotaDimLabels = map[string]string{
quotaDimTotal: "总限额 / Total",
}
+// AccountQuotaReader provides read access to account quota data.
+type AccountQuotaReader interface {
+ GetByID(ctx context.Context, id int64) (*Account, error)
+}
+
// BalanceNotifyService handles balance and quota threshold notifications.
type BalanceNotifyService struct {
emailService *EmailService
settingRepo SettingRepository
+ accountRepo AccountQuotaReader
}
// NewBalanceNotifyService creates a new BalanceNotifyService.
-func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository) *BalanceNotifyService {
+func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountQuotaReader) *BalanceNotifyService {
return &BalanceNotifyService{
emailService: emailService,
settingRepo: settingRepo,
+ accountRepo: accountRepo,
}
}
@@ -110,7 +117,7 @@ func buildQuotaDims(account *Account) []quotaDim {
}
// CheckAccountQuotaAfterIncrement checks if any quota dimension crossed above its notify threshold.
-// The account's Extra fields contain pre-increment usage values.
+// It fetches real-time quota usage from DB to avoid stale snapshot values.
func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Context, account *Account, cost float64) {
if account == nil || s.emailService == nil || s.settingRepo == nil || cost <= 0 {
return
@@ -123,8 +130,29 @@ func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Conte
return
}
+ freshAccount := s.fetchFreshAccount(ctx, account)
siteName := s.getSiteName(ctx)
- for _, dim := range buildQuotaDims(account) {
+ s.checkQuotaDimCrossings(freshAccount, cost, adminEmails, siteName)
+}
+
+// fetchFreshAccount loads the latest account from DB; falls back to the snapshot on error.
+func (s *BalanceNotifyService) fetchFreshAccount(ctx context.Context, snapshot *Account) *Account {
+ if s.accountRepo == nil {
+ return snapshot
+ }
+ fresh, err := s.accountRepo.GetByID(ctx, snapshot.ID)
+ if err != nil {
+ slog.Warn("failed to fetch fresh account for quota notify, using snapshot",
+ "account_id", snapshot.ID, "error", err)
+ return snapshot
+ }
+ return fresh
+}
+
+// checkQuotaDimCrossings iterates quota dimensions and sends alerts for threshold crossings.
+// freshAccount has post-increment values; oldUsed is reconstructed as freshUsed - cost.
+func (s *BalanceNotifyService) checkQuotaDimCrossings(freshAccount *Account, cost float64, adminEmails []string, siteName string) {
+ for _, dim := range buildQuotaDims(freshAccount) {
if !dim.enabled || dim.threshold <= 0 {
continue
}
@@ -132,9 +160,12 @@ func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Conte
if effectiveThreshold <= 0 {
continue
}
- newUsed := dim.oldUsed + cost
- if dim.oldUsed < effectiveThreshold && newUsed >= effectiveThreshold {
- s.asyncSendQuotaAlert(adminEmails, account.Name, dim, newUsed, effectiveThreshold, siteName)
+ // dim.oldUsed is actually the post-increment value from fresh DB data;
+ // reconstruct pre-increment value to detect threshold crossing.
+ newUsed := dim.oldUsed
+ oldUsed := dim.oldUsed - cost
+ if oldUsed < effectiveThreshold && newUsed >= effectiveThreshold {
+ s.asyncSendQuotaAlert(adminEmails, freshAccount.Name, dim, newUsed, effectiveThreshold, siteName)
}
}
}
diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go
index e2badfed..55cb2c84 100644
--- a/backend/internal/service/gateway_request.go
+++ b/backend/internal/service/gateway_request.go
@@ -75,6 +75,9 @@ type ParsedRequest struct {
MaxTokens int // max_tokens 值(用于探测请求拦截)
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
+ // GroupID 请求所属分组 ID(来自 API Key)
+ GroupID *int64
+
// OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁)
// 流式请求在收到 2xx 响应头后调用,避免持锁等流完成
OnUpstreamAccepted func()
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 70dd9b52..5267156d 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -3789,7 +3789,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
- if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.Body) {
+ if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
return s.handleWebSearchEmulation(ctx, c, account, parsed)
}
@@ -7588,7 +7588,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
upstreamModel = result.Model
}
usageLog.AccountStatsCost = resolveAccountStatsCost(
- ctx, s.channelService,
+ ctx, s.channelService, s.billingService,
account.ID, *apiKey.GroupID, upstreamModel,
UsageTokens{
InputTokens: result.Usage.InputTokens,
diff --git a/backend/internal/service/gateway_websearch_emulation.go b/backend/internal/service/gateway_websearch_emulation.go
index b3a4aa69..0d0b5480 100644
--- a/backend/internal/service/gateway_websearch_emulation.go
+++ b/backend/internal/service/gateway_websearch_emulation.go
@@ -49,10 +49,9 @@ func getWebSearchManager() *websearch.Manager {
// shouldEmulateWebSearch checks whether a request should be intercepted.
//
-// Judgment chain: manager exists → only web_search tool → global enabled → account enabled.
-// Note: channel-level control is enforced via the account's extra field; the channel toggle
-// in the admin UI sets the account's flag for all accounts in that channel's groups.
-func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, body []byte) bool {
+// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled.
+// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel).
+func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool {
if getWebSearchManager() == nil {
return false
}
@@ -62,10 +61,23 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac
if !s.settingService.IsWebSearchEmulationEnabled(ctx) {
return false
}
- if !account.IsWebSearchEmulationEnabled() {
+
+ mode := account.GetWebSearchEmulationMode()
+ switch mode {
+ case WebSearchModeEnabled:
+ return true
+ case WebSearchModeDisabled:
return false
+ default: // "default" → follow channel config
+ if groupID == nil || s.channelService == nil {
+ return false
+ }
+ ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
+ if err != nil || ch == nil {
+ return false
+ }
+ return ch.IsWebSearchEmulationEnabled(account.Platform)
}
- return true
}
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 98258cd0..e060b981 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -4580,7 +4580,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
statsModel = result.Model
}
usageLog.AccountStatsCost = resolveAccountStatsCost(
- ctx, s.channelService,
+ ctx, s.channelService, s.billingService,
account.ID, *apiKey.GroupID, statsModel,
tokens, 1,
)
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index b4e33039..9f33c46a 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -476,8 +476,8 @@ func ProvidePaymentConfigService(entClient *dbent.Client, settingRepo SettingRep
}
// ProvideBalanceNotifyService creates BalanceNotifyService
-func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository) *BalanceNotifyService {
- return NewBalanceNotifyService(emailService, settingRepo)
+func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountRepository) *BalanceNotifyService {
+ return NewBalanceNotifyService(emailService, settingRepo, accountRepo)
}
// ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService.
diff --git a/backend/migrations/105_migrate_websearch_emulation_to_tristate.sql b/backend/migrations/105_migrate_websearch_emulation_to_tristate.sql
new file mode 100644
index 00000000..745e58df
--- /dev/null
+++ b/backend/migrations/105_migrate_websearch_emulation_to_tristate.sql
@@ -0,0 +1,11 @@
+-- Convert old boolean web_search_emulation to tri-state string
+-- true → "enabled", false → remove key (becomes "default")
+UPDATE accounts
+SET extra = (extra - 'web_search_emulation') || jsonb_build_object('web_search_emulation', 'enabled')
+WHERE extra ? 'web_search_emulation'
+ AND extra->>'web_search_emulation' = 'true';
+
+UPDATE accounts
+SET extra = extra - 'web_search_emulation'
+WHERE extra ? 'web_search_emulation'
+ AND extra->>'web_search_emulation' = 'false';
diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue
index e0dbce61..e83e061e 100644
--- a/frontend/src/components/account/CreateAccountModal.vue
+++ b/frontend/src/components/account/CreateAccountModal.vue
@@ -2337,7 +2337,11 @@
{{ t('admin.accounts.anthropic.webSearchEmulationDesc') }}
-
+
+ {{ t('admin.accounts.anthropic.webSearchDefault') }}
+ {{ t('admin.accounts.anthropic.webSearchEnabled') }}
+ {{ t('admin.accounts.anthropic.webSearchDisabled') }}
+
@@ -2846,7 +2850,6 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Select from '@/components/common/Select.vue'
import Icon from '@/components/icons/Icon.vue'
import ProxySelector from '@/components/common/ProxySelector.vue'
-import Toggle from '@/components/common/Toggle.vue'
import GroupSelector from '@/components/common/GroupSelector.vue'
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
@@ -2997,7 +3000,7 @@ const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF
const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
const anthropicPassthroughEnabled = ref(false)
-const webSearchEmulationEnabled = ref(false)
+const webSearchEmulationMode = ref('default')
const webSearchGlobalEnabled = ref(false)
// Load web search global state once
@@ -3331,7 +3334,7 @@ watch(
}
if (newPlatform !== 'anthropic') {
anthropicPassthroughEnabled.value = false
- webSearchEmulationEnabled.value = false
+ webSearchEmulationMode.value = 'default'
}
// Reset OAuth states
oauth.resetState()
@@ -3351,7 +3354,7 @@ watch(
}
if (platform !== 'anthropic' || category !== 'apikey') {
anthropicPassthroughEnabled.value = false
- webSearchEmulationEnabled.value = false
+ webSearchEmulationMode.value = 'default'
}
}
)
@@ -3716,7 +3719,7 @@ const resetForm = () => {
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
anthropicPassthroughEnabled.value = false
- webSearchEmulationEnabled.value = false
+ webSearchEmulationMode.value = 'default'
// Reset quota control state
windowCostEnabled.value = false
windowCostLimit.value = null
@@ -3804,10 +3807,10 @@ const buildAnthropicExtra = (base?: Record): Record 0 ? extra : undefined
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index abb9569e..74e5fc1f 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -1161,7 +1161,11 @@
{{ t('admin.accounts.anthropic.webSearchEmulationDesc') }}
-
+
+ {{ t('admin.accounts.anthropic.webSearchDefault') }}
+ {{ t('admin.accounts.anthropic.webSearchEnabled') }}
+ {{ t('admin.accounts.anthropic.webSearchDisabled') }}
+
@@ -1844,7 +1848,6 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Select from '@/components/common/Select.vue'
import Icon from '@/components/icons/Icon.vue'
import ProxySelector from '@/components/common/ProxySelector.vue'
-import Toggle from '@/components/common/Toggle.vue'
import GroupSelector from '@/components/common/GroupSelector.vue'
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
@@ -1986,7 +1989,7 @@ const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF
const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
const anthropicPassthroughEnabled = ref(false)
-const webSearchEmulationEnabled = ref(false)
+const webSearchEmulationMode = ref('default')
const webSearchGlobalEnabled = ref(false)
// Load web search global state once
@@ -2171,7 +2174,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
anthropicPassthroughEnabled.value = false
- webSearchEmulationEnabled.value = false
+ webSearchEmulationMode.value = 'default'
if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) {
openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true
openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, {
@@ -2192,7 +2195,15 @@ const syncFormFromAccount = (newAccount: Account | null) => {
}
if (newAccount.platform === 'anthropic' && newAccount.type === 'apikey') {
anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true
- webSearchEmulationEnabled.value = extra?.web_search_emulation === true
+ // Tri-state: string "default"/"enabled"/"disabled"; backward-compatible with the old bool value
+ const wsVal = extra?.web_search_emulation
+ if (wsVal === 'enabled' || wsVal === 'disabled') {
+ webSearchEmulationMode.value = wsVal
+ } else if (wsVal === true) {
+ webSearchEmulationMode.value = 'enabled'
+ } else {
+ webSearchEmulationMode.value = 'default'
+ }
}
// Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above)
@@ -3180,10 +3191,10 @@ const handleSubmit = async () => {
} else {
delete newExtra.anthropic_passthrough
}
- if (webSearchEmulationEnabled.value) {
- newExtra.web_search_emulation = true
- } else {
+ if (webSearchEmulationMode.value === 'default') {
delete newExtra.web_search_emulation
+ } else {
+ newExtra.web_search_emulation = webSearchEmulationMode.value
}
updatePayload.extra = newExtra
}
diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue
index f4494e69..92c8dd34 100644
--- a/frontend/src/components/admin/usage/UsageTable.vue
+++ b/frontend/src/components/admin/usage/UsageTable.vue
@@ -279,13 +279,21 @@
{{ t('admin.usage.outputCost') }}
${{ tooltipData.output_cost.toFixed(6) }}
-
- {{ t('usage.inputTokenPrice') }}
- {{ formatTokenPricePerMillion(tooltipData.input_cost, tooltipData.input_tokens) }} {{ t('usage.perMillionTokens') }}
-
-
-
{{ t('usage.outputTokenPrice') }}
-
{{ formatTokenPricePerMillion(tooltipData.output_cost, tooltipData.output_tokens) }} {{ t('usage.perMillionTokens') }}
+
+
+
+ {{ t('usage.inputTokenPrice') }}
+ {{ formatTokenPricePerMillion(tooltipData.input_cost, tooltipData.input_tokens) }} {{ t('usage.perMillionTokens') }}
+
+
+ {{ t('usage.outputTokenPrice') }}
+ {{ formatTokenPricePerMillion(tooltipData.output_cost, tooltipData.output_tokens) }} {{ t('usage.perMillionTokens') }}
+
+
+
+
+ {{ tooltipData.billing_mode === 'image' ? t('usage.imageUnitPrice') : t('usage.unitPrice') }}
+ ${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}
{{ t('admin.usage.cacheCreationCost') }}
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index f45a02f6..056bdde0 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -774,6 +774,8 @@ export default {
inputTokenPrice: 'Input price',
outputTokenPrice: 'Output price',
perMillionTokens: '/ 1M tokens',
+ unitPrice: 'Per-request price',
+ imageUnitPrice: 'Per-image price',
cacheRead: 'Read',
cacheWrite: 'Write',
serviceTier: 'Service tier',
@@ -1877,14 +1879,15 @@ export default {
pricingEntry: 'Pricing Entry',
noModels: 'No models added',
applyPricingToAccountStats: 'Apply Pricing to Account Stats',
- applyPricingToAccountStatsDesc: 'When enabled, custom account stats model pricing rules will be applied.',
+ applyPricingToAccountStatsDesc: 'When enabled, requests not matched by custom rules will use standard model pricing for account stats calculation',
accountStatsPricingRules: 'Custom Account Stats Pricing Rules',
addRule: 'Add Rule',
noRulesConfigured: 'No custom rules configured. Channel model pricing above will be used.',
ruleName: 'Rule name (optional)',
ruleGroups: 'Groups',
- ruleAccounts: 'Account IDs',
- ruleAccountsPlaceholder: 'Enter account IDs, comma-separated',
+ ruleAccounts: 'Accounts',
+ searchAccountPlaceholder: 'Search accounts...',
+ ruleAccountsHint: 'Leave empty to match all accounts',
ruleModelPricing: 'Model Pricing',
noGroupsInChannel: 'No groups selected in platform tabs above'
}
@@ -2380,6 +2383,9 @@ export default {
webSearchEmulation: 'Web Search Emulation',
webSearchEmulationDesc:
'Enable web search emulation for this API Key account. When a pure web_search request is detected, the gateway calls a third-party search API and constructs the response locally.',
+ webSearchDefault: 'Default (follow channel)',
+ webSearchEnabled: 'Enabled',
+ webSearchDisabled: 'Disabled',
},
modelRestriction: 'Model Restriction (Optional)',
modelWhitelist: 'Model Whitelist',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 61d43a37..47a1f8d4 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -778,6 +778,8 @@ export default {
inputTokenPrice: '输入单价',
outputTokenPrice: '输出单价',
perMillionTokens: '/ 1M Token',
+ unitPrice: '单次价格',
+ imageUnitPrice: '单张价格',
cacheRead: '读取',
cacheWrite: '写入',
serviceTier: '服务档位',
@@ -1956,14 +1958,15 @@ export default {
pricingEntry: '定价配置',
noModels: '未添加模型',
applyPricingToAccountStats: '应用模型定价到账号统计',
- applyPricingToAccountStatsDesc: '启用后将支持自定义账号统计的模型价格',
+ applyPricingToAccountStatsDesc: '启用后,未被自定义规则匹配的请求将使用模型定价文件中的标准价格计算账号统计费用',
accountStatsPricingRules: '自定义账号统计定价规则',
addRule: '添加规则',
noRulesConfigured: '未配置自定义规则,将使用上方的模型定价。',
ruleName: '规则名称(可选)',
ruleGroups: '分组',
- ruleAccounts: '账号 ID',
- ruleAccountsPlaceholder: '输入账号 ID,逗号分隔',
+ ruleAccounts: '账号',
+ searchAccountPlaceholder: '搜索账号...',
+ ruleAccountsHint: '留空表示匹配所有账号',
ruleModelPricing: '模型定价',
noGroupsInChannel: '上方平台标签页中未选择分组'
}
@@ -2527,6 +2530,9 @@ export default {
webSearchEmulation: 'Web Search 模拟',
webSearchEmulationDesc:
'为该 API Key 账号启用 web search 模拟。客户端发送纯 web_search 请求时,由网关调用第三方搜索 API 并构造响应返回。',
+ webSearchDefault: '默认(跟随渠道)',
+ webSearchEnabled: '开启',
+ webSearchDisabled: '关闭',
},
modelRestriction: '模型限制(可选)',
modelWhitelist: '模型白名单',
diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue
index 2e45ba42..5be45cb5 100644
--- a/frontend/src/views/admin/ChannelsView.vue
+++ b/frontend/src/views/admin/ChannelsView.vue
@@ -413,8 +413,8 @@
-
-
+
+
{{ t('admin.channels.form.accountStatsPricingRules') }}
@@ -474,12 +474,51 @@
{{ t('admin.channels.form.ruleAccounts') }}
-
+
+
+
+ {{ getRuleAccountLabel(accountId) }}
+
+
+
+
+
+
+
+
+
+
+
+ {{ account.name }}
+ #{{ account.id }}
+
+
+
+
+ {{ t('admin.channels.form.ruleAccountsHint') }}
+
@@ -569,6 +608,7 @@ import PlatformIcon from '@/components/common/PlatformIcon.vue'
import Toggle from '@/components/common/Toggle.vue'
import PricingEntryCard from '@/components/admin/channel/PricingEntryCard.vue'
import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
+import { useKeyedDebouncedSearch } from '@/composables/useKeyedDebouncedSearch'
const { t } = useI18n()
const appStore = useAppStore()
@@ -852,6 +892,9 @@ function addRulePricingEntry(ruleIndex: number) {
function removeAccountStatsRule(ruleIndex: number) {
form.account_stats_pricing_rules.splice(ruleIndex, 1)
+ // Clear all search state since indices shift after removal
+ ruleAccountSearchRunner.clearAll()
+ clearAllRuleAccountSearchState()
}
function removeRulePricingEntry(ruleIndex: number, pricingIndex: number) {
@@ -863,11 +906,78 @@ function getGroupNameById(groupId: number): string {
return group ? group.name : `#${groupId}`
}
-function parseAccountIdsInput(value: string): number[] {
- return value
- .split(',')
- .map(s => parseInt(s.trim()))
- .filter(n => !isNaN(n) && n > 0)
+// ── Account search for pricing rules ──
+interface SimpleAccount { id: number; name: string }
+
+const ruleAccountSearchKeyword = ref>({})
+const ruleAccountSearchResults = ref>({})
+const showRuleAccountDropdown = ref>({})
+// Cache: account ID → name, populated when search results are selected
+const ruleAccountNameCache = ref>({})
+
+const ruleAccountSearchRunner = useKeyedDebouncedSearch({
+ delay: 300,
+ search: async (keyword, { key, signal }) => {
+ const platform = key.split('-')[0]
+ const res = await adminAPI.accounts.list(1, 20, { platform, search: keyword }, { signal })
+ return res.items.map(a => ({ id: a.id, name: a.name }))
+ },
+ onSuccess: (key, result) => { ruleAccountSearchResults.value[key] = result },
+ onError: (key) => { ruleAccountSearchResults.value[key] = [] },
+})
+
+function onRuleAccountSearchInput(platform: string, ruleIndex: number) {
+ const key = `${platform}-${ruleIndex}`
+ showRuleAccountDropdown.value[key] = true
+ ruleAccountSearchRunner.trigger(key, ruleAccountSearchKeyword.value[key] || '')
+}
+
+function onRuleAccountSearchFocus(platform: string, ruleIndex: number) {
+ const key = `${platform}-${ruleIndex}`
+ showRuleAccountDropdown.value[key] = true
+ if (!ruleAccountSearchResults.value[key]?.length) {
+ ruleAccountSearchRunner.trigger(key, ruleAccountSearchKeyword.value[key] || '')
+ }
+}
+
+function selectRuleAccount(
+ rule: { account_ids: number[] },
+ account: SimpleAccount,
+ platform: string,
+ ruleIndex: number,
+) {
+ if (!rule.account_ids.includes(account.id)) {
+ rule.account_ids.push(account.id)
+ ruleAccountNameCache.value[account.id] = account.name
+ }
+ const key = `${platform}-${ruleIndex}`
+ ruleAccountSearchKeyword.value[key] = ''
+ showRuleAccountDropdown.value[key] = false
+}
+
+function removeRuleAccount(rule: { account_ids: number[] }, accountId: number) {
+ const idx = rule.account_ids.indexOf(accountId)
+ if (idx !== -1) rule.account_ids.splice(idx, 1)
+}
+
+function getRuleAccountLabel(accountId: number): string {
+ const name = ruleAccountNameCache.value[accountId]
+ return name ? `${name} #${accountId}` : `#${accountId}`
+}
+
+function handleRuleAccountClickOutside(event: MouseEvent) {
+ const target = event.target as HTMLElement
+ if (!target.closest('.rule-account-search-container')) {
+ Object.keys(showRuleAccountDropdown.value).forEach(key => {
+ showRuleAccountDropdown.value[key] = false
+ })
+ }
+}
+
+function clearAllRuleAccountSearchState() {
+ ruleAccountSearchKeyword.value = {}
+ ruleAccountSearchResults.value = {}
+ showRuleAccountDropdown.value = {}
}
function accountStatsRulesToAPI(): AccountStatsPricingRule[] {
@@ -1093,6 +1203,9 @@ function resetForm() {
form.apply_pricing_to_account_stats = false
form.account_stats_pricing_rules = []
activeTab.value = 'basic'
+ ruleAccountSearchRunner.clearAll()
+ clearAllRuleAccountSearchState()
+ ruleAccountNameCache.value = {}
}
async function openCreateDialog() {
@@ -1313,11 +1426,15 @@ onMounted(() => {
loadChannels()
loadGroups()
loadWebSearchGlobalState()
+ document.addEventListener('click', handleRuleAccountClickOutside)
})
onUnmounted(() => {
clearTimeout(searchTimeout)
abortController?.abort()
+ document.removeEventListener('click', handleRuleAccountClickOutside)
+ ruleAccountSearchRunner.clearAll()
+ clearAllRuleAccountSearchState()
})
diff --git a/frontend/src/views/user/UsageView.vue b/frontend/src/views/user/UsageView.vue
index 6cb367ed..2ec0fea5 100644
--- a/frontend/src/views/user/UsageView.vue
+++ b/frontend/src/views/user/UsageView.vue
@@ -447,13 +447,21 @@
{{ t('admin.usage.outputCost') }}
${{ tooltipData.output_cost.toFixed(6) }}
-
- {{ t('usage.inputTokenPrice') }}
- {{ formatTokenPricePerMillion(tooltipData.input_cost, tooltipData.input_tokens) }} {{ t('usage.perMillionTokens') }}
-
-
-
{{ t('usage.outputTokenPrice') }}
-
{{ formatTokenPricePerMillion(tooltipData.output_cost, tooltipData.output_tokens) }} {{ t('usage.perMillionTokens') }}
+
+
+
+ {{ t('usage.inputTokenPrice') }}
+ {{ formatTokenPricePerMillion(tooltipData.input_cost, tooltipData.input_tokens) }} {{ t('usage.perMillionTokens') }}
+
+
+ {{ t('usage.outputTokenPrice') }}
+ {{ formatTokenPricePerMillion(tooltipData.output_cost, tooltipData.output_tokens) }} {{ t('usage.perMillionTokens') }}
+
+
+
+
+ {{ tooltipData.billing_mode === 'image' ? t('usage.imageUnitPrice') : t('usage.unitPrice') }}
+ ${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}
{{ t('admin.usage.cacheCreationCost') }}
From a68df457d80dd67b173df2a74fd69de1206c0f2e Mon Sep 17 00:00:00 2001
From: erio
Date: Mon, 13 Apr 2026 12:07:09 +0800
Subject: [PATCH 42/88] fix: address audit findings across websearch, notify,
and channel pricing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Backend fixes:
- Fix balance notify ignoring percentage threshold type (was treating
percentage value as fixed USD amount)
- Remove dead code parseJSONStringArray
- Add ImageOutputTokens to tryModelFilePricing calculation
- Unify zero-value check: cost == 0 → cost <= 0 in calculateTokenStatsCost
- Use MarshalNotifyEmails instead of json.Marshal for consistency
- Rename quotaDim.oldUsed → currentUsed for clarity
- Extract HTML email templates to const variables (function ≤30 lines)
Test fixes:
- Rewrite account_websearch_test.go for GetWebSearchEmulationMode tri-state
- Add 6 tryModelFilePricing test cases
Frontend fixes:
- Replace hardcoded '未命名' with i18n key
- Extract getBillingModeLabel/getBillingModeBadgeClass to shared utils
- Replace inline type with imported NotifyEmailEntry
- Pass platform to AccountStats pricing rules via inferRulePlatform()
- Add billing mode constants (BILLING_MODE_TOKEN/PER_REQUEST/IMAGE)
---
.../internal/service/account_stats_pricing.go | 5 +-
.../service/account_stats_pricing_test.go | 99 +++++++++++++++++++
.../service/account_websearch_test.go | 96 ++++++++++++------
.../service/balance_notify_service.go | 79 ++++++++-------
backend/internal/service/setting_service.go | 6 +-
.../src/components/admin/usage/UsageTable.vue | 14 +--
frontend/src/i18n/locales/en.ts | 3 +-
frontend/src/i18n/locales/zh.ts | 3 +-
frontend/src/utils/billingMode.ts | 19 ++++
frontend/src/views/admin/ChannelsView.vue | 52 ++++++----
frontend/src/views/admin/SettingsView.vue | 4 +-
frontend/src/views/user/UsageView.vue | 16 +--
12 files changed, 275 insertions(+), 121 deletions(-)
create mode 100644 frontend/src/utils/billingMode.ts
diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go
index e88f7f8c..cbe9c76c 100644
--- a/backend/internal/service/account_stats_pricing.go
+++ b/backend/internal/service/account_stats_pricing.go
@@ -57,7 +57,8 @@ func tryModelFilePricing(billingService *BillingService, model string, tokens Us
cost := float64(tokens.InputTokens)*pricing.InputPricePerToken +
float64(tokens.OutputTokens)*pricing.OutputPricePerToken +
float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken +
- float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken
+ float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken +
+ float64(tokens.ImageOutputTokens)*pricing.ImageOutputPricePerToken
if cost <= 0 {
return nil
}
@@ -194,7 +195,7 @@ func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *
float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) +
float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) +
float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice)
- if cost == 0 {
+ if cost <= 0 {
return nil
}
return &cost
diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go
index bc3db251..bf9da978 100644
--- a/backend/internal/service/account_stats_pricing_test.go
+++ b/backend/internal/service/account_stats_pricing_test.go
@@ -428,3 +428,102 @@ func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) {
require.NotNil(t, result)
require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2
}
+
+// ---------------------------------------------------------------------------
+// tryModelFilePricing
+// ---------------------------------------------------------------------------
+
+// newTestBillingServiceWithPrices creates a BillingService with pre-populated
+// fallback prices for testing. No config or pricing service is needed.
+// The key must match what getFallbackPricing resolves to for a given model name.
+// E.g., model "claude-sonnet-4" resolves to key "claude-sonnet-4".
+func newTestBillingServiceWithPrices(prices map[string]*ModelPricing) *BillingService {
+ return &BillingService{
+ fallbackPrices: prices,
+ }
+}
+
+func TestTryModelFilePricing_Success(t *testing.T) {
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ },
+ })
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
+ require.InDelta(t, 0.2, *result, 1e-12)
+}
+
+func TestTryModelFilePricing_PricingNotFound(t *testing.T) {
+ // "nonexistent-model" does not match any fallback pattern
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{})
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+ result := tryModelFilePricing(bs, "nonexistent-model", tokens)
+ require.Nil(t, result)
+}
+
+func TestTryModelFilePricing_NilFallback(t *testing.T) {
+ // getFallbackPricing returns nil when key maps to nil
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": nil,
+ })
+ tokens := UsageTokens{InputTokens: 100}
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.Nil(t, result)
+}
+
+func TestTryModelFilePricing_ZeroCost(t *testing.T) {
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ },
+ })
+ tokens := UsageTokens{} // all zero tokens → cost = 0 → nil
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.Nil(t, result)
+}
+
+func TestTryModelFilePricing_WithImageOutput(t *testing.T) {
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ ImageOutputPricePerToken: 0.01,
+ },
+ })
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ ImageOutputTokens: 10,
+ }
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
+ require.InDelta(t, 0.3, *result, 1e-12)
+}
+
+func TestTryModelFilePricing_WithCacheTokens(t *testing.T) {
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ CacheCreationPricePerToken: 0.003,
+ CacheReadPricePerToken: 0.0005,
+ },
+ })
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ CacheCreationTokens: 200,
+ CacheReadTokens: 300,
+ }
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
+ // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
+ require.InDelta(t, 0.95, *result, 1e-12)
+}
diff --git a/backend/internal/service/account_websearch_test.go b/backend/internal/service/account_websearch_test.go
index fe742ebf..b4d23c6b 100644
--- a/backend/internal/service/account_websearch_test.go
+++ b/backend/internal/service/account_websearch_test.go
@@ -1,3 +1,5 @@
+//go:build unit
+
package service
import (
@@ -6,66 +8,98 @@ import (
"github.com/stretchr/testify/require"
)
-func TestAccount_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
+func TestGetWebSearchEmulationMode_Enabled(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
+ }
+ require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_Disabled(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "disabled"},
+ }
+ require.Equal(t, WebSearchModeDisabled, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_Default(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "default"},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_UnknownString(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "unknown"},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_OldBoolTrue(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: true},
}
- require.True(t, a.IsWebSearchEmulationEnabled())
+ // bool is not a string, type assertion fails → default
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
-func TestAccount_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
+func TestGetWebSearchEmulationMode_OldBoolFalse(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: false},
}
- require.False(t, a.IsWebSearchEmulationEnabled())
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
-func TestAccount_IsWebSearchEmulationEnabled_MissingField(t *testing.T) {
+func TestGetWebSearchEmulationMode_NilAccount(t *testing.T) {
+ var a *Account
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_NilExtra(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: nil,
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_MissingField(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{},
}
- require.False(t, a.IsWebSearchEmulationEnabled())
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
-func TestAccount_IsWebSearchEmulationEnabled_WrongType(t *testing.T) {
- a := &Account{
- Platform: PlatformAnthropic,
- Type: AccountTypeAPIKey,
- Extra: map[string]any{featureKeyWebSearchEmulation: "true"},
- }
- require.False(t, a.IsWebSearchEmulationEnabled())
-}
-
-func TestAccount_IsWebSearchEmulationEnabled_NilExtra(t *testing.T) {
- a := &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: nil}
- require.False(t, a.IsWebSearchEmulationEnabled())
-}
-
-func TestAccount_IsWebSearchEmulationEnabled_NilAccount(t *testing.T) {
- var a *Account
- require.False(t, a.IsWebSearchEmulationEnabled())
-}
-
-func TestAccount_IsWebSearchEmulationEnabled_NonAnthropicPlatform(t *testing.T) {
+func TestGetWebSearchEmulationMode_NonAnthropicPlatform(t *testing.T) {
a := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
- Extra: map[string]any{featureKeyWebSearchEmulation: true},
+ Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
}
- require.False(t, a.IsWebSearchEmulationEnabled())
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
-func TestAccount_IsWebSearchEmulationEnabled_NonAPIKeyType(t *testing.T) {
+func TestGetWebSearchEmulationMode_NonAPIKeyType(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
- Extra: map[string]any{featureKeyWebSearchEmulation: true},
+ Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
}
- require.False(t, a.IsWebSearchEmulationEnabled())
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go
index 23411ed5..1e4d8ff6 100644
--- a/backend/internal/service/balance_notify_service.go
+++ b/backend/internal/service/balance_notify_service.go
@@ -2,7 +2,6 @@ package service
import (
"context"
- "encoding/json"
"fmt"
"html"
"log/slog"
@@ -14,6 +13,10 @@ import (
const (
emailSendTimeout = 30 * time.Second
+ // Threshold type values
+ thresholdTypeFixed = "fixed"
+ thresholdTypePercentage = "percentage"
+
// Quota dimension labels
quotaDimDaily = "daily"
quotaDimWeekly = "weekly"
@@ -48,6 +51,15 @@ func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepo
}
}
+// resolveBalanceThreshold returns the effective balance threshold.
+// For percentage type, it computes threshold = totalRecharged * percentage / 100.
+func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 {
+ if thresholdType == thresholdTypePercentage && totalRecharged > 0 {
+ return totalRecharged * threshold / 100
+ }
+ return threshold
+}
+
// CheckBalanceAfterDeduction checks if balance crossed below threshold after deduction.
// oldBalance is the balance before deduction, cost is the amount deducted.
// Notification is sent only on first crossing: oldBalance >= threshold && newBalance < threshold.
@@ -73,8 +85,13 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u
return
}
+ effectiveThreshold := resolveBalanceThreshold(threshold, user.BalanceNotifyThresholdType, user.TotalRecharged)
+ if effectiveThreshold <= 0 {
+ return
+ }
+
newBalance := oldBalance - cost
- if oldBalance >= threshold && newBalance < threshold {
+ if oldBalance >= effectiveThreshold && newBalance < effectiveThreshold {
siteName := s.getSiteName(ctx)
recipients := s.collectBalanceNotifyRecipients(user)
go func() {
@@ -83,7 +100,7 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u
slog.Error("panic in balance notification", "recover", r)
}
}()
- s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName)
+ s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, effectiveThreshold, siteName)
}()
}
}
@@ -94,14 +111,14 @@ type quotaDim struct {
enabled bool
threshold float64
thresholdType string // "fixed" (default) or "percentage"
- oldUsed float64
+ currentUsed float64
limit float64
}
// resolvedThreshold returns the effective threshold value.
// For percentage type, it computes threshold = limit * percentage / 100.
func (d quotaDim) resolvedThreshold() float64 {
- if d.thresholdType == "percentage" && d.limit > 0 {
+ if d.thresholdType == thresholdTypePercentage && d.limit > 0 {
return d.limit * d.threshold / 100
}
return d.threshold
@@ -150,7 +167,7 @@ func (s *BalanceNotifyService) fetchFreshAccount(ctx context.Context, snapshot *
}
// checkQuotaDimCrossings iterates quota dimensions and sends alerts for threshold crossings.
-// freshAccount has post-increment values; oldUsed is reconstructed as freshUsed - cost.
+// freshAccount has post-increment values; pre-increment is reconstructed as currentUsed - cost.
func (s *BalanceNotifyService) checkQuotaDimCrossings(freshAccount *Account, cost float64, adminEmails []string, siteName string) {
for _, dim := range buildQuotaDims(freshAccount) {
if !dim.enabled || dim.threshold <= 0 {
@@ -160,10 +177,10 @@ func (s *BalanceNotifyService) checkQuotaDimCrossings(freshAccount *Account, cos
if effectiveThreshold <= 0 {
continue
}
- // dim.oldUsed is actually the post-increment value from fresh DB data;
+ // currentUsed is the post-increment value from fresh DB data;
// reconstruct pre-increment value to detect threshold crossing.
- newUsed := dim.oldUsed
- oldUsed := dim.oldUsed - cost
+ newUsed := dim.currentUsed
+ oldUsed := dim.currentUsed - cost
if oldUsed < effectiveThreshold && newUsed >= effectiveThreshold {
s.asyncSendQuotaAlert(adminEmails, freshAccount.Name, dim, newUsed, effectiveThreshold, siteName)
}
@@ -309,10 +326,9 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension)
}
-// buildBalanceLowEmailBody builds HTML email for balance low notification.
-// Lines exceed 30 due to inline HTML template (not splittable).
-func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance, threshold float64, siteName string) string {
- return fmt.Sprintf(`
+// balanceLowEmailTemplate is the HTML template for balance low notifications.
+// Format args: siteName, userName, userName, balance, threshold, threshold.
+const balanceLowEmailTemplate = `
@@ -344,17 +360,11 @@ func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance
-`, siteName, userName, userName, balance, threshold, threshold)
-}
+`
-// buildQuotaAlertEmailBody builds HTML email for account quota alert.
-// Lines exceed 30 due to inline HTML template (not splittable).
-func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountName, dimLabel string, used, limit, threshold float64, siteName string) string {
- limitStr := fmt.Sprintf("$%.2f", limit)
- if limit <= 0 {
- limitStr = "无限制 / Unlimited"
- }
- return fmt.Sprintf(`
+// quotaAlertEmailTemplate is the HTML template for account quota alert notifications.
+// Format args: siteName, accountName, dimLabel, used, limitStr, threshold.
+const quotaAlertEmailTemplate = `
@@ -389,18 +399,19 @@ func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountName, dimLabel st
-`, siteName, accountName, dimLabel, used, limitStr, threshold)
+