From 1cd033e521b02e23fabf23628303663985b34114 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 9 Mar 2026 19:24:19 +0800 Subject: [PATCH 01/88] style: apply gofmt formatting Co-Authored-By: Claude Opus 4.6 --- backend/internal/repository/gateway_cache.go | 257 +++++++++++++++++- .../service/admin_service_apikey_test.go | 72 +---- backend/internal/service/user_service_test.go | 9 +- 3 files changed, 257 insertions(+), 81 deletions(-) diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 58291b66..ec4bf40e 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -2,14 +2,42 @@ package repository import ( "context" + _ "embed" "fmt" + "strconv" "time" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" ) -const stickySessionPrefix = "sticky_session:" +const ( + stickySessionPrefix = "sticky_session:" + clientAffinityPrefix = "client_affinity:" + clientAffinityReversePrefix = "client_affinity_rev:" +) + +var ( + //go:embed lua/get_affinity.lua + getAffinityLua string + //go:embed lua/update_affinity.lua + updateAffinityLua string + //go:embed lua/get_affinity_count.lua + getAffinityCountLua string + //go:embed lua/get_affinity_clients.lua + getAffinityClientsLua string + //go:embed lua/get_affinity_clients_with_scores.lua + getAffinityClientsWithScoresLua string + //go:embed lua/clear_account_affinity.lua + clearAccountAffinityLua string + + getAffinityScript = redis.NewScript(getAffinityLua) + updateAffinityScript = redis.NewScript(updateAffinityLua) + getAffinityCountScript = redis.NewScript(getAffinityCountLua) + getAffinityClientsScript = redis.NewScript(getAffinityClientsLua) + getAffinityClientsWithScoresScript = redis.NewScript(getAffinityClientsWithScoresLua) + clearAccountAffinityScript = redis.NewScript(clearAccountAffinityLua) +) type gatewayCache struct { rdb *redis.Client @@ -19,6 +47,16 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache { return &gatewayCache{rdb: rdb} } +// ensureScriptLoaded 确保 Lua 脚本已加载到 Redis 服务器的脚本缓存中。 +// Pipeline 中的 Script.Run 只发送 EVALSHA,如果 Redis 重启过导致脚本缓存丢失, +// EVALSHA 会返回 NOSCRIPT 错误。此方法提前加载脚本以避免该问题。 +func ensureScriptLoaded(ctx context.Context, rdb *redis.Client, script *redis.Script) { + exists, err := script.Exists(ctx, rdb).Result() + if err != nil || len(exists) == 0 || !exists[0] { + _ = script.Load(ctx, rdb).Err() + } +} + // buildSessionKey 构建 session key,包含 groupID 实现分组隔离 // 格式: sticky_session:{groupID}:{sessionHash} func buildSessionKey(groupID int64, sessionHash string) string { @@ -41,13 +79,218 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses } // DeleteSessionAccountID 删除粘性会话与账号的绑定关系。 -// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用, -// 以便下次请求能够重新选择可用账号。 -// -// DeleteSessionAccountID removes the sticky session binding for the given session. -// Called when the bound account becomes unavailable (e.g., error status, disabled, -// or unschedulable), allowing subsequent requests to select a new available account. func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { key := buildSessionKey(groupID, sessionHash) return c.rdb.Del(ctx, key).Err() } + +// buildAffinityKey 构建正向亲和 key(client → accounts) +// 格式: client_affinity:{groupID}:{clientID} +func buildAffinityKey(groupID int64, clientID string) string { + return fmt.Sprintf("%s%d:%s", clientAffinityPrefix, groupID, clientID) +} + +// buildAffinityReverseKey 构建反向亲和 key(account → clients) +// 格式: client_affinity_rev:{groupID}:{accountID} +func buildAffinityReverseKey(groupID int64, accountID int64) string { + return fmt.Sprintf("%s%d:%d", clientAffinityReversePrefix, groupID, accountID) +} + +func (c *gatewayCache) GetClientAffinityAccounts(ctx context.Context, groupID int64, clientID string, ttl time.Duration) ([]int64, error) { + key := buildAffinityKey(groupID, clientID) + now := time.Now().Unix() + expireThreshold := now - int64(ttl.Seconds()) + + result, err := getAffinityScript.Run(ctx, c.rdb, []string{key}, expireThreshold).StringSlice() + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, err + } + + accountIDs := make([]int64, 0, len(result)) + for _, s := range result { + id, err := strconv.ParseInt(s, 10, 64) + if err != nil { + continue + } + accountIDs = append(accountIDs, id) + } + return accountIDs, nil +} + +func (c *gatewayCache) UpdateClientAffinity(ctx context.Context, groupID int64, clientID string, accountID int64, ttl time.Duration) error { + fwdKey := buildAffinityKey(groupID, clientID) + revKey := buildAffinityReverseKey(groupID, accountID) + now := time.Now().Unix() + ttlSeconds := int64(ttl.Seconds()) + expireThreshold := now - ttlSeconds + + return updateAffinityScript.Run(ctx, c.rdb, []string{fwdKey, revKey}, + now, ttlSeconds, accountID, expireThreshold, clientID, + ).Err() +} + +// GetAccountAffinityCountBatch 批量获取账号的亲和客户端数量(惰性清理过期成员) +func (c *gatewayCache) GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error) { + if len(accountIDs) == 0 { + return map[int64]int64{}, nil + } + + now := time.Now().Unix() + expireThreshold := now - int64(ttl.Seconds()) + + ensureScriptLoaded(ctx, c.rdb, getAffinityCountScript) + + pipe := c.rdb.Pipeline() + cmds := make([]*redis.Cmd, len(accountIDs)) + for i, accID := range accountIDs { + key := buildAffinityReverseKey(groupID, accID) + cmds[i] = getAffinityCountScript.Run(ctx, pipe, []string{key}, expireThreshold) + } + _, err := pipe.Exec(ctx) + if err != nil && err != redis.Nil { + return nil, err + } + + result := make(map[int64]int64, len(accountIDs)) + for i, accID := range accountIDs { + count, _ := cmds[i].Int64() + result[accID] = count + } + return result, nil +} + +// GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和客户端列表(去重)。 +// accountGroups: map[accountID][]groupID,对每个 (groupID, accountID) 组合查询反向索引。 +func (c *gatewayCache) GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error) { + if len(accountGroups) == 0 { + return map[int64][]string{}, nil + } + + now := time.Now().Unix() + expireThreshold := now - int64(ttl.Seconds()) + + // 构建所有 (accountID, groupID) 组合的查询 + type queryItem struct { + accountID int64 + groupID int64 + } + var queries []queryItem + for accID, groupIDs := range accountGroups { + for _, gID := range groupIDs { + queries = append(queries, queryItem{accountID: accID, groupID: gID}) + } + } + + ensureScriptLoaded(ctx, c.rdb, getAffinityClientsScript) + + pipe := c.rdb.Pipeline() + cmds := make([]*redis.Cmd, len(queries)) + for i, q := range queries { + key := buildAffinityReverseKey(q.groupID, q.accountID) + cmds[i] = getAffinityClientsScript.Run(ctx, pipe, []string{key}, expireThreshold) + } + _, err := pipe.Exec(ctx) + if err != nil && err != redis.Nil { + return nil, err + } + + // 合并结果:同一个 accountID 跨多个 group 的 clientID 去重 + result := make(map[int64][]string, len(accountGroups)) + seen := make(map[int64]map[string]struct{}, len(accountGroups)) + for i, q := range queries { + clients, _ := cmds[i].StringSlice() + if len(clients) == 0 { + continue + } + if seen[q.accountID] == nil { + seen[q.accountID] = make(map[string]struct{}) + } + for _, clientID := range clients { + if _, exists := seen[q.accountID][clientID]; !exists { + seen[q.accountID][clientID] = struct{}{} + result[q.accountID] = append(result[q.accountID], clientID) + } + } + } + return result, nil +} + +// GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间戳,去重取最近)。 +func (c *gatewayCache) GetAccountAffinityClientsWithScores( + ctx context.Context, + accountID int64, + groupIDs []int64, + ttl time.Duration, +) ([]service.AffinityClient, error) { + if len(groupIDs) == 0 { + return nil, nil + } + + now := time.Now().Unix() + expireThreshold := now - int64(ttl.Seconds()) + + ensureScriptLoaded(ctx, c.rdb, getAffinityClientsWithScoresScript) + + pipe := c.rdb.Pipeline() + cmds := make([]*redis.Cmd, len(groupIDs)) + for i, gID := range groupIDs { + key := buildAffinityReverseKey(gID, accountID) + cmds[i] = getAffinityClientsWithScoresScript.Run(ctx, pipe, []string{key}, expireThreshold) + } + _, err := pipe.Exec(ctx) + if err != nil && err != redis.Nil { + return nil, err + } + + // 合并跨组结果,同一 clientID 取最近的 lastActive + seen := make(map[string]int64) // clientID → max timestamp + for _, cmd := range cmds { + vals, _ := cmd.StringSlice() + // vals 格式: [clientID1, score1, clientID2, score2, ...] + for j := 0; j+1 < len(vals); j += 2 { + clientID := vals[j] + ts, _ := strconv.ParseInt(vals[j+1], 10, 64) + if existing, ok := seen[clientID]; !ok || ts > existing { + seen[clientID] = ts + } + } + } + + result := make([]service.AffinityClient, 0, len(seen)) + for clientID, ts := range seen { + result = append(result, service.AffinityClient{ + ClientID: clientID, + LastActive: time.Unix(ts, 0), + }) + } + + // 按最后活跃时间降序排序 + service.SortAffinityClients(result) + + return result, nil +} + +// ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引)。 +// 对每个 groupID 执行 Lua 脚本:读取反向索引获取所有客户端, +// 从每个客户端的正向索引中移除该账号,然后删除反向索引。 +func (c *gatewayCache) ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error { + if len(groupIDs) == 0 { + return nil + } + + ensureScriptLoaded(ctx, c.rdb, clearAccountAffinityScript) + + pipe := c.rdb.Pipeline() + for _, gID := range groupIDs { + revKey := buildAffinityReverseKey(gID, accountID) + clearAccountAffinityScript.Run(ctx, pipe, []string{revKey}, gID, accountID) + } + _, err := pipe.Exec(ctx) + if err != nil && err != redis.Nil { + return err + } + return nil +} diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index f9fd6742..5c18a438 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -65,9 +65,6 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { - panic("unexpected") -} func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") } @@ -131,9 +128,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { panic("unexpected") } -func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) { - panic("unexpected") -} func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) { panic("unexpected") } @@ -200,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) { panic("unexpected") } -func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) { +func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) { panic("unexpected") } func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { @@ -216,29 +210,6 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS panic("unexpected") } -type userSubRepoStubForGroupUpdate struct { - userSubRepoNoop - getActiveSub *UserSubscription - getActiveErr error - called bool - calledUserID int64 - calledGroupID int64 -} - -func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) { - s.called = true - s.calledUserID = userID - s.calledGroupID = groupID - if s.getActiveErr != nil { - return nil, s.getActiveErr - } - if s.getActiveSub == nil { - return nil, ErrSubscriptionNotFound - } - clone := *s.getActiveSub - return &clone, nil -} - // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -431,49 +402,14 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) { existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} - groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}} - userRepo := &userRepoStubForGroupUpdate{} - userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound} - svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo} - - // 无有效订阅时应拒绝绑定 - _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) - require.Error(t, err) - require.Equal(t, "SUBSCRIPTION_REQUIRED", infraerrors.Reason(err)) - require.True(t, userSubRepo.called) - require.Equal(t, int64(42), userSubRepo.calledUserID) - require.Equal(t, int64(10), userSubRepo.calledGroupID) - require.False(t, userRepo.addGroupCalled) -} - -func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) { - existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} - apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} - groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}} userRepo := &userRepoStubForGroupUpdate{} svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo} + // 订阅类型分组应被阻止绑定 _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) require.Error(t, err) - require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err)) - require.False(t, userRepo.addGroupCalled) -} - -func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) { - existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} - apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} - groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}} - userRepo := &userRepoStubForGroupUpdate{} - userSubRepo := &userSubRepoStubForGroupUpdate{ - getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10}, - } - svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo} - - got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) - require.NoError(t, err) - require.True(t, userSubRepo.called) - require.NotNil(t, got.APIKey.GroupID) - require.Equal(t, int64(10), *got.APIKey.GroupID) + require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err)) require.False(t, userRepo.addGroupCalled) } diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index e88694f5..7f6c748f 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -46,12 +46,9 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int return 0, nil } func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } -func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { - return nil -} -func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } -func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } -func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } // --- mock: APIKeyAuthCacheInvalidator --- From 3de77130175138146981d2e8d34ec8eb19a614b9 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 30 Mar 2026 21:47:06 +0800 Subject: [PATCH 02/88] =?UTF-8?q?fix(channel):=20splice=E6=9B=BF=E6=8D=A2m?= =?UTF-8?q?odel=5Fpricing=E6=9D=A1=E7=9B=AE=20+=20=E5=A2=9E=E5=BC=BA?= =?UTF-8?q?=E8=B0=83=E8=AF=95=E6=97=A5=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/cmd/server/VERSION | 2 +- frontend/src/views/admin/ChannelsView.vue | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 4b9b35d8..1ebb081e 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.112 +0.1.105.13 diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index ebfc1b5a..5c2f153b 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -970,6 +970,7 @@ async function handleSubmit() { } const { group_ids, model_pricing, model_mapping } = formToAPI() + console.log('[handleSubmit] model_pricing to send:', JSON.stringify(model_pricing)) submitting.value = true try { From 2dce4306b4409e355f6ff265fa30ae7a2d3a6221 Mon Sep 17 00:00:00 2001 From: erio Date: Thu, 2 Apr 2026 13:24:30 +0800 Subject: [PATCH 03/88] refactor: move channel model restriction from handler to scheduling phase Move the model pricing restriction check from 8 handler entry points to the account scheduling phase (SelectAccountForModelWithExclusions / SelectAccountWithLoadAwareness), aligning restriction with billing: - requested: check original request model against pricing list - channel_mapped: check channel-mapped model against pricing list - upstream: per-account check using account-mapped model Handler layer now only resolves channel mapping (no restriction). Scheduling layer performs pre-check for requested/channel_mapped, and per-account filtering for upstream billing source. --- .../gateway_handler_chat_completions.go | 2 +- .../handler/gateway_handler_responses.go | 2 +- .../internal/handler/gemini_v1beta_handler.go | 14 +- .../handler/openai_chat_completions.go | 2 +- .../handler/openai_gateway_handler.go | 46 +- backend/internal/service/gateway_service.go | 1202 +++++++++++------ 6 files changed, 793 insertions(+), 475 deletions(-) diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index be267332..abe2a1e5 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -80,7 +80,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) - // 解析渠道级模型映射 + // 解析渠道级模型映射 + 限制检查 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) // Claude Code only restriction diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index e908eb9e..cf877182 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -80,7 +80,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) - // 解析渠道级模型映射 + // 解析渠道级模型映射 + 限制检查 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) // Claude Code only restriction: diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index d200c17c..ff63bc7f 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { googleError(c, http.StatusBadGateway, err.Error()) return } - if shouldFallbackGeminiModel(modelName, res) { + if shouldFallbackGeminiModels(res) { c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) return } @@ -184,7 +184,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { setOpsRequestContext(c, modelName, stream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) - // 解析渠道级模型映射 + // 解析渠道级模型映射 + 限制检查 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName) reqModel := modelName // 保存映射前的原始模型名 if channelMapping.Mapped { @@ -682,16 +682,6 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { return false } -func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool { - if shouldFallbackGeminiModels(res) { - return true - } - if res == nil || res.StatusCode != http.StatusNotFound { - return false - } - return gemini.HasFallbackModel(modelName) -} - // extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。 // 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。 // diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 991cbb91..ada401c9 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -79,7 +79,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) - // 解析渠道级模型映射 + // 解析渠道级模型映射 + 限制检查 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) if h.errorPassthroughService != nil { diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 5319b55d..2b081617 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -47,13 +47,6 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode return strings.TrimSpace(apiKey.Group.DefaultMappedModel) } -func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string { - if apiKey == nil || apiKey.Group == nil { - return "" - } - return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel)) -} - // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, @@ -557,8 +550,6 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { return } reqModel := modelResult.String() - routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel) - preferredMappedModel := resolveOpenAIMessagesDispatchMappedModel(apiKey, reqModel) reqStream := gjson.GetBytes(body, "stream").Bool() reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) @@ -617,20 +608,17 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { failedAccountIDs := make(map[int64]struct{}) sameAccountRetryCount := make(map[int64]int) var lastFailoverErr *service.UpstreamFailoverError - effectiveMappedModel := preferredMappedModel for { - currentRoutingModel := routingModel - if effectiveMappedModel != "" { - currentRoutingModel = effectiveMappedModel - } + // 清除上一次迭代的降级模型标记,避免残留影响本次迭代 + c.Set("openai_messages_fallback_model", "") reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( c.Request.Context(), apiKey.GroupID, "", // no previous_response_id sessionHash, - currentRoutingModel, + reqModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, ) @@ -639,7 +627,29 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)), ) + // 首次调度失败 + 有默认映射模型 → 用默认模型重试 if len(failedAccountIDs) == 0 { + defaultModel := "" + if apiKey.Group != nil { + defaultModel = apiKey.Group.DefaultMappedModel + } + if defaultModel != "" && defaultModel != reqModel { + reqLog.Info("openai_messages.fallback_to_default_model", + zap.String("default_mapped_model", defaultModel), + ) + selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + sessionHash, + defaultModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err == nil && selection != nil { + c.Set("openai_messages_fallback_model", defaultModel) + } + } if err != nil { h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return @@ -671,7 +681,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - defaultMappedModel := strings.TrimSpace(effectiveMappedModel) + // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的 + // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。 + defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model")) // 应用渠道模型映射到请求体 forwardBody := body if channelMappingMsg.Mapped { @@ -1106,7 +1118,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { setOpsRequestContext(c, reqModel, true, firstMessage) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) - // 解析渠道级模型映射 + // 解析渠道级模型映射 + 限制检查 channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel) var currentUserRelease func() diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 8b0bdc2a..33ab38f2 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -12,7 +12,6 @@ import ( "log/slog" mathrand "math/rand" "net/http" - "net/url" "os" "path/filepath" "regexp" @@ -42,7 +41,8 @@ import ( const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" - stickySessionTTL = time.Hour // 粘性会话TTL + stickySessionTTL = time.Hour // 粘性会话TTL + ClientAffinityTTL = 24 * time.Hour // 客户端亲和TTL defaultMaxLineSize = 500 * 1024 * 1024 // Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines) // to match real Claude CLI traffic as closely as possible. When we need a visual @@ -60,14 +60,28 @@ const ( claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) +// MediaType 媒体类型常量 +const ( + MediaTypeImage = "image" + MediaTypeVideo = "video" + MediaTypePrompt = "prompt" +) + +const ( + claudeMaxMessageOverheadTokens = 3 + claudeMaxBlockOverheadTokens = 1 + claudeMaxUnknownContentTokens = 4 +) + // ForceCacheBillingContextKey 强制缓存计费上下文键 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 type forceCacheBillingKeyType struct{} // accountWithLoad 账号与负载信息的组合,用于负载感知调度 type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo + account *Account + loadInfo *AccountLoadInfo + affinityCount int64 // 亲和客户端数量(反向索引),越少越优先 } var ForceCacheBillingContextKey = forceCacheBillingKeyType{} @@ -331,6 +345,10 @@ var ( sseDataRe = regexp.MustCompile(`^data:\s*`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) + // clientIDFromMetadataRegex 从 metadata.user_id 中提取客户端 ID(64位 hex) + // 格式: user_{64位hex}_account_... + clientIDFromMetadataRegex = regexp.MustCompile(`^user_([a-f0-9]{64})_account_`) + // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 // 注意:前缀之间不应存在包含关系,否则会导致冗余匹配 @@ -348,6 +366,12 @@ var ErrNoAvailableAccounts = errors.New("no available accounts") // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") +// ErrAffinityNoSwitch 表示亲和账号不可用且不允许切换到其他账号 +var ErrAffinityNoSwitch = errors.New("affinity account unavailable and switching is disabled") + +// ErrAffinityLimitExceeded 表示亲和客户端限制已达上限 +var ErrAffinityLimitExceeded = errors.New("affinity client limit exceeded") + // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, @@ -369,8 +393,6 @@ var allowedHeaders = map[string]bool{ "user-agent": true, "content-type": true, "accept-encoding": true, - "x-claude-code-session-id": true, - "x-client-request-id": true, } // GatewayCache 定义网关服务的缓存操作接口。 @@ -391,6 +413,39 @@ type GatewayCache interface { // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 // Delete sticky session binding, used to proactively clean up when account becomes unavailable DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error + + // GetAffinityAccounts 获取亲和账号列表(按最近使用降序),同时清理过期成员 + GetAffinityAccounts(ctx context.Context, groupID int64, userID int64, clientID string, ttl time.Duration) ([]int64, error) + // UpdateAffinity 添加/更新亲和关系(更新 score 为当前时间戳,刷新 key TTL) + UpdateAffinity(ctx context.Context, groupID int64, userID int64, clientID string, accountID int64, ttl time.Duration) error + // GetAccountAffinityCountBatch 批量获取账号的亲和成员数量(惰性清理过期成员) + GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error) + // GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和成员列表(去重) + // accountGroups: map[accountID][]groupID + // 返回值成员格式为 {userID}/{clientID} + GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error) + // GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间) + GetAccountAffinityClientsWithScores(ctx context.Context, accountID int64, groupIDs []int64, ttl time.Duration) ([]AffinityClient, error) + // ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引) + // 用于账号关闭亲和时立即清理旧绑定 + ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error + // GetAffinityMultiCount 获取账号的多维度亲和计数 + // 返回: uniqueUsers, uniqueClients, perUserClients + GetAffinityMultiCount(ctx context.Context, groupID int64, accountID int64, targetUserID int64, ttl time.Duration) (users, clients, perUser int64, err error) +} + +// AffinityClient 亲和客户端信息(含用户 ID 和最后活跃时间) +type AffinityClient struct { + UserID int64 `json:"user_id"` + ClientID string `json:"client_id"` + LastActive time.Time `json:"last_active"` +} + +// SortAffinityClients 按最后活跃时间降序排序 +func SortAffinityClients(clients []AffinityClient) { + sort.Slice(clients, func(i, j int) bool { + return clients[i].LastActive.After(clients[j].LastActive) + }) } // derefGroupID safely dereferences *int64 to int64, returning 0 if nil @@ -461,6 +516,20 @@ func shouldClearStickySession(account *Account, requestedModel string) bool { return false } +// extractClientIDFromMetadata 从 metadata.user_id 中提取客户端 ID(64位 hex)。 +// 格式: user_{64位hex}_account_..._session_... +// 返回空字符串表示无法提取(非 Claude Code/Console 客户端)。 +func extractClientIDFromMetadata(metadataUserID string) string { + if metadataUserID == "" { + return "" + } + matches := clientIDFromMetadataRegex.FindStringSubmatch(metadataUserID) + if matches == nil { + return "" + } + return matches[1] +} + type AccountWaitPlan struct { AccountID int64 MaxConcurrency int @@ -504,6 +573,9 @@ type ForwardResult struct { ImageCount int // 生成的图片数量 ImageSize string // 图片尺寸 "1K", "2K", "4K" + // Sora 媒体字段 + MediaType string // image / video / prompt + MediaURL string // 生成后的媒体地址(可选) } // UpstreamFailoverError indicates an upstream error that should trigger account failover. @@ -1162,6 +1234,11 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { + // 渠道定价限制预检查(requested / channel_mapped 基准) + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // 优先检查 context 中的强制平台(/antigravity 路由) var platform string forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) @@ -1180,32 +1257,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context platform = PlatformAnthropic } - // Claude Code 限制可能已将 groupID 解析为 fallback group, - // 渠道限制预检查必须使用解析后的分组。 - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - slog.Warn("channel pricing restriction blocked request", - "group_id", derefGroupID(groupID), - "model", requestedModel) - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) - } - // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 注意:强制平台模式不走混合调度 if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { - account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) - if err != nil { - return nil, err - } - return s.hydrateSelectedAccount(ctx, account) + return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } // antigravity 分组、强制平台模式或无分组使用单平台选择 // 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询 - account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) - if err != nil { - return nil, err - } - return s.hydrateSelectedAccount(ctx, account) + return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. @@ -1213,6 +1273,11 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context // metadataUserID: 用于客户端亲和调度,从中提取客户端 ID // sub2apiUserID: 系统用户 ID,用于二维亲和调度 func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) { + // 渠道定价限制预检查(requested / channel_mapped 基准) + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // 调试日志:记录调度入口参数 excludedIDsList := make([]int64, 0, len(excludedIDs)) for id := range excludedIDs { @@ -1233,15 +1298,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } ctx = s.withGroupContext(ctx, group) - // Claude Code 限制可能已将 groupID 解析为 fallback group, - // 渠道限制预检查必须使用解析后的分组。 - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - slog.Warn("channel pricing restriction blocked request", - "group_id", derefGroupID(groupID), - "model", requestedModel) - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) - } - var stickyAccountID int64 if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { stickyAccountID = prefetch @@ -1251,6 +1307,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } + // 提取客户端 ID(用于客户端亲和调度) + affinityClientID := extractClientIDFromMetadata(metadataUserID) + affinityUserID := sub2apiUserID + if s.debugModelRoutingEnabled() && requestedModel != "" { groupPlatform := "" if group != nil { @@ -1272,6 +1332,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if err != nil { return nil, err } + if shouldFilterAccountWithoutClientID(account, affinityClientID) { + localExcluded[account.ID] = struct{}{} + continue + } result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) if err == nil && result.Acquired { @@ -1281,7 +1345,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro localExcluded[account.ID] = struct{}{} // 排除此账号 continue // 重新选择 } - return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } // 对于等待计划的情况,也需要先检查会话限制 @@ -1293,20 +1361,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) if waitingCount < cfg.StickySessionMaxWaiting { - return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }) + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil } } - return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }) + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil } } @@ -1323,12 +1397,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if err != nil { return nil, err } + accounts = filterAccountsWithoutClientID(accounts, affinityClientID) if len(accounts) == 0 { return nil, ErrNoAvailableAccounts } ctx = s.withWindowCostPrefetch(ctx, accounts) ctx = s.withRPMPrefetch(ctx, accounts) + // 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用) + accountByID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + accountByID[accounts[i].ID] = &accounts[i] + } isExcluded := func(accountID int64) bool { if excludedIDs == nil { return false @@ -1336,12 +1416,19 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro _, excluded := excludedIDs[accountID] return excluded } - - // 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用) - accountByID := make(map[int64]*Account, len(accounts)) - for i := range accounts { - accountByID[accounts[i].ID] = &accounts[i] - } + affinityFlow := newGatewayAffinityFlow( + s, + ctx, + groupID, + sessionHash, + requestedModel, + affinityClientID, + affinityUserID, + platform, + useMixed, + accountByID, + isExcluded, + ) // 获取模型路由配置(仅 anthropic 平台) var routingAccountIDs []int64 @@ -1430,76 +1517,53 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { - var stickyCacheMissReason string - - gatePass := s.isAccountSchedulableForSelection(stickyAccount) && + if s.isAccountSchedulableForSelection(stickyAccount) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && s.isAccountSchedulableForQuota(stickyAccount) && - s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) + s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && - rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true) - - if rpmPass { // 粘性会话窗口费用+RPM 检查 + s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { result.ReleaseFunc() // 释放槽位 - stickyCacheMissReason = "session_limit" // 继续到负载感知选择 } else { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } - return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil) + return &AccountSelectionResult{ + Account: stickyAccount, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } } - if stickyCacheMissReason == "" { - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) - if waitingCount < cfg.StickySessionMaxWaiting { - // 会话数量限制检查(等待计划也需要占用会话配额) - if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { - stickyCacheMissReason = "session_limit" - // 会话限制已满,继续到负载感知选择 - } else { - return &AccountSelectionResult{ - Account: stickyAccount, - WaitPlan: &AccountWaitPlan{ - AccountID: stickyAccountID, - MaxConcurrency: stickyAccount.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil - } + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) + if waitingCount < cfg.StickySessionMaxWaiting { + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { + // 会话限制已满,继续到负载感知选择 } else { - stickyCacheMissReason = "wait_queue_full" + return &AccountSelectionResult{ + Account: stickyAccount, + WaitPlan: &AccountWaitPlan{ + AccountID: stickyAccountID, + MaxConcurrency: stickyAccount.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil } } // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 - } else if !gatePass { - stickyCacheMissReason = "gate_check" - } else { - stickyCacheMissReason = "rpm_red" - } - - // 记录粘性缓存未命中的结构化日志 - if stickyCacheMissReason != "" { - baseRPM := stickyAccount.GetBaseRPM() - var currentRPM int - if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok { - currentRPM = count - } - logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d", - stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM) } } else { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) - logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0", - stickyAccountID, shortSessionHash(sessionHash)) } } } @@ -1527,7 +1591,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } if len(routingAvailable) > 0 { - // 排序:优先级 > 负载率 > 最后使用时间 + // 批量获取亲和客户端数量 + s.populateAffinityCounts(ctx, routingAvailable, derefGroupID(groupID)) + + // 排序:优先级 > 负载率 > 亲和客户端数 > 最后使用时间 sort.SliceStable(routingAvailable, func(i, j int) bool { a, b := routingAvailable[i], routingAvailable[j] if a.account.Priority != b.account.Priority { @@ -1536,6 +1603,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if a.loadInfo.LoadRate != b.loadInfo.LoadRate { return a.loadInfo.LoadRate < b.loadInfo.LoadRate } + if a.affinityCount != b.affinityCount { + return a.affinityCount < b.affinityCount + } switch { case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: return true @@ -1561,10 +1631,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) } + if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && item.account.IsAffinityEnabled() { + _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, item.account.ID, ClientAffinityTTL) + } if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } - return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil) + return &AccountSelectionResult{ + Account: item.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } } @@ -1577,12 +1654,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } - return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{ - AccountID: item.account.ID, - MaxConcurrency: item.account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }) + return &AccountSelectionResult{ + Account: item.account, + WaitPlan: &AccountWaitPlan{ + AccountID: item.account.ID, + MaxConcurrency: item.account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil } // 所有路由账号会话限制都已满,继续到 Layer 2 回退 } @@ -1591,14 +1671,27 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============ - if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { + // ============ Layer 1.3: 用户亲和预处理(pinned_users 自动注入) ============ + affinityFlow.preprocessPinnedUsers(accounts) + + // ============ Layer 1.4: 客户端亲和调度(优先于粘性会话) ============ + affinityHit := false + if affinityResult, hit, err := affinityFlow.trySelectAffinityAccount(); err != nil { + return nil, err + } else { + affinityHit = hit + if affinityResult != nil { + return affinityResult, nil + } + } + + // ============ Layer 1.5: 粘性会话(仅在无模型路由配置 且 亲和未命中时生效) ============ + if !affinityHit && len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { accountID := stickyAccountID if accountID > 0 && !isExcluded(accountID) { account, ok := accountByID[accountID] if ok { // 检查账户是否需要清理粘性会话绑定 - // Check if the account needs sticky session cleanup clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) @@ -1614,31 +1707,32 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 - // Session count limit check if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 } else { - if s.cache != nil { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) - } - return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } } waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) if waitingCount < cfg.StickySessionMaxWaiting { // 会话数量限制检查(等待计划也需要占用会话配额) - // Session count limit check (wait plan also requires session quota) if !s.checkAndRegisterSession(ctx, account, sessionHash) { // 会话限制已满,继续到 Layer 2 - // Session limit full, continue to Layer 2 } else { - return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }) + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil } } } @@ -1697,9 +1791,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { - if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil { - return nil, legacyErr - } else if ok { + if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { + if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && result.Account != nil && result.Account.IsAffinityEnabled() { + _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, result.Account.ID, ClientAffinityTTL) + } return result, nil } } else { @@ -1717,13 +1812,37 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // 分层过滤选择:优先级 → 负载率 → LRU + // 批量获取亲和客户端数量(用于均衡分配新客户端) + s.populateAffinityCounts(ctx, available, derefGroupID(groupID)) + + // 分层过滤选择:优先级 → 亲和三区 → 负载率 → 亲和客户端数 → LRU for len(available) > 0 { // 1. 取优先级最小的集合 candidates := filterByMinPriority(available) - // 2. 取负载率最低的集合 + // 2. 按亲和三区过滤:绿区优先 → 黄区降级 → 红区移除(在同优先级内) + candidates = classifyByAffinityZone(candidates) + if len(candidates) == 0 { + // 当前优先级组全部在红区,移除后回退到下一优先级组 + minPri := available[0].account.Priority + for _, a := range available[1:] { + if a.account.Priority < minPri { + minPri = a.account.Priority + } + } + newAvailable := make([]accountWithLoad, 0, len(available)) + for _, a := range available { + if a.account.Priority != minPri { + newAvailable = append(newAvailable, a) + } + } + available = newAvailable + continue + } + // 3. 取负载率最低的集合 candidates = filterByMinLoadRate(candidates) - // 3. LRU 选择最久未用的账号 + // 3. 取亲和客户端数最少的集合 + candidates = filterByMinAffinityCount(candidates) + // 4. LRU 选择最久未用的账号 selected := selectByLRU(candidates, preferOAuth) if selected == nil { break @@ -1738,7 +1857,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) } - return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil) + // 更新亲和关系 + if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && selected.account.IsAffinityEnabled() { + _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, selected.account.ID, ClientAffinityTTL) + } + return &AccountSelectionResult{ + Account: selected.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } } @@ -1761,17 +1888,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, acc, sessionHash) { continue // 会话限制已满,尝试下一个账号 } - return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }) + return &AccountSelectionResult{ + Account: acc, + WaitPlan: &AccountWaitPlan{ + AccountID: acc.ID, + MaxConcurrency: acc.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil } return nil, ErrNoAvailableAccounts } -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) { +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) @@ -1786,15 +1916,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) } - selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil) - if err != nil { - return nil, false, err - } - return selection, true, nil + return &AccountSelectionResult{ + Account: acc, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, true } } - return nil, false, nil + return nil, false } func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { @@ -1939,6 +2069,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr } func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { + if platform == PlatformSora { + return s.listSoraSchedulableAccounts(ctx, groupID) + } if s.schedulerSnapshot != nil { accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err == nil { @@ -2035,6 +2168,53 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i return accounts, useMixed, nil } +func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) { + const useMixed = false + + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } else if groupID != nil { + accounts, err = s.accountRepo.ListByGroup(ctx, *groupID) + } else { + accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } + if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "error", err) + return nil, useMixed, err + } + + filtered := make([]Account, 0, len(accounts)) + for _, acc := range accounts { + if acc.Platform != PlatformSora { + continue + } + if !s.isSoraAccountSchedulable(&acc) { + continue + } + filtered = append(filtered, acc) + } + slog.Debug("account_scheduling_list_sora", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "raw_count", len(accounts), + "filtered_count", len(filtered)) + for _, acc := range filtered { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + return filtered, useMixed, nil +} + // IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 // 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, // 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 @@ -2059,10 +2239,33 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform return account.Platform == platform } +func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool { + return s.soraUnschedulableReason(account) == "" +} + +func (s *GatewayService) soraUnschedulableReason(account *Account) string { + if account == nil { + return "account_nil" + } + if account.Status != StatusActive { + return fmt.Sprintf("status=%s", account.Status) + } + if !account.Schedulable { + return "schedulable=false" + } + if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { + return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339)) + } + return "" +} + func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool { if account == nil { return false } + if account.Platform == PlatformSora { + return s.isSoraAccountSchedulable(account) + } return account.IsSchedulable() } @@ -2070,6 +2273,12 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte if account == nil { return false } + if account.Platform == PlatformSora { + if !s.isSoraAccountSchedulable(account) { + return false + } + return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0 + } return account.IsSchedulableForModelWithContext(ctx, requestedModel) } @@ -2409,31 +2618,34 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in return s.accountRepo.GetByID(ctx, accountID) } -func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { - if account == nil || s.schedulerSnapshot == nil { - return account, nil +// populateAffinityCounts 批量获取账号的亲和客户端数量并填入 accountWithLoad 切片。 +// 仅当存在开启了客户端亲和的账号时才查询 Redis,否则跳过。 +func (s *GatewayService) populateAffinityCounts(ctx context.Context, accounts []accountWithLoad, groupID int64) { + if s.cache == nil || len(accounts) == 0 { + return } - hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID) + // 快速检查:是否有任何账号开启了亲和 + hasAffinity := false + for _, acc := range accounts { + if acc.account.IsAffinityEnabled() { + hasAffinity = true + break + } + } + if !hasAffinity { + return + } + accountIDs := make([]int64, len(accounts)) + for i, acc := range accounts { + accountIDs[i] = acc.account.ID + } + countMap, err := s.cache.GetAccountAffinityCountBatch(ctx, groupID, accountIDs, ClientAffinityTTL) if err != nil { - return nil, err + return // 查询失败不影响调度,affinityCount 保持 0 } - if hydrated == nil { - return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID) + for i := range accounts { + accounts[i].affinityCount = countMap[accounts[i].account.ID] } - return hydrated, nil -} - -func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) { - hydrated, err := s.hydrateSelectedAccount(ctx, account) - if err != nil { - return nil, err - } - return &AccountSelectionResult{ - Account: hydrated, - Acquired: acquired, - ReleaseFunc: release, - WaitPlan: waitPlan, - }, nil } // filterByMinPriority 过滤出优先级最小的账号集合 @@ -2476,6 +2688,64 @@ func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad { return result } +// filterByMinAffinityCount 过滤出亲和客户端数最少的账号集合 +func filterByMinAffinityCount(accounts []accountWithLoad) []accountWithLoad { + if len(accounts) == 0 { + return accounts + } + minCount := accounts[0].affinityCount + for _, acc := range accounts[1:] { + if acc.affinityCount < minCount { + minCount = acc.affinityCount + } + } + result := make([]accountWithLoad, 0, len(accounts)) + for _, acc := range accounts { + if acc.affinityCount == minCount { + result = append(result, acc) + } + } + return result +} + +// classifyByAffinityZone 按亲和分区对候选账号进行分类。 +// 返回值:仅绿区账号(有绿区时),否则返回黄区账号。红区账号被移除。 +// 如果没有任何账号开启了亲和三区配置(即 affinity_base <= 0),则原样返回所有账号。 +func classifyByAffinityZone(accounts []accountWithLoad) []accountWithLoad { + if len(accounts) == 0 { + return accounts + } + // 快速检查:是否有任何账号配置了 affinity_base + hasZoneConfig := false + for _, acc := range accounts { + if acc.account.IsAffinityEnabled() && acc.account.GetAffinityBase() > 0 { + hasZoneConfig = true + break + } + } + if !hasZoneConfig { + return accounts + } + + greens := make([]accountWithLoad, 0, len(accounts)) + yellows := make([]accountWithLoad, 0, len(accounts)) + for _, acc := range accounts { + zone := acc.account.GetAffinityZone(acc.affinityCount) + switch zone { + case AffinityZoneGreen: + greens = append(greens, acc) + case AffinityZoneYellow: + yellows = append(yellows, acc) + case AffinityZoneRed: + // 红区:移除,不参与调度 + } + } + if len(greens) > 0 { + return greens + } + return yellows +} + // selectByLRU 从集合中选择最久未用的账号 // 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个 func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad { @@ -2711,12 +2981,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) - // require_privacy_set: 获取分组信息 - var schedGroup *Group - if groupID != nil && s.groupRepo != nil { - schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) - } - var accounts []Account accountsLoaded := false @@ -2788,12 +3052,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } - // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 - if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { - _ = s.accountRepo.SetError(ctx, acc.ID, - fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) - continue - } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } @@ -2885,8 +3143,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持) - // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查, - // 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。 needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { @@ -2899,12 +3155,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } - // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 - if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { - _ = s.accountRepo.SetError(ctx, acc.ID, - fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) - continue - } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } @@ -2971,12 +3221,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g preferOAuth := nativePlatform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform) - // require_privacy_set: 获取分组信息 - var schedGroup *Group - if groupID != nil && s.groupRepo != nil { - schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) - } - var accounts []Account accountsLoaded := false @@ -3044,12 +3288,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } - // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 - if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { - _ = s.accountRepo.SetError(ctx, acc.ID, - fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) - continue - } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue @@ -3143,7 +3381,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) - // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。 needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { @@ -3156,12 +3393,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } - // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 - if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { - _ = s.accountRepo.SetError(ctx, acc.ID, - fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) - continue - } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue @@ -3273,6 +3504,9 @@ func (s *GatewayService) logDetailedSelectionFailure( stats.SampleMappingIDs, stats.SampleRateLimitIDs, ) + if platform == PlatformSora { + s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling) + } return stats } @@ -3329,7 +3563,11 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "excluded"} } if !s.isAccountSchedulableForSelection(acc) { - return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"} + detail := "generic_unschedulable" + if acc.Platform == PlatformSora { + detail = s.soraUnschedulableReason(acc) + } + return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} } if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { return selectionFailureDiagnosis{ @@ -3353,6 +3591,57 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "eligible"} } +func (s *GatewayService) logSoraSelectionFailureDetails( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + accounts []Account, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) { + const maxLines = 30 + logged := 0 + + for i := range accounts { + if logged >= maxLines { + break + } + acc := &accounts[i] + diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling) + if diagnosis.Category == "eligible" { + continue + } + detail := diagnosis.Detail + if detail == "" { + detail = "-" + } + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + acc.ID, + acc.Platform, + diagnosis.Category, + detail, + ) + logged++ + } + if len(accounts) > maxLines { + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + len(accounts), + logged, + ) + } +} + func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { if acc == nil { return true @@ -3431,10 +3720,17 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo } return mapAntigravityModel(account, requestedModel) != "" } + if account.Platform == PlatformSora { + return s.isSoraModelSupportedByAccount(account, requestedModel) + } if account.IsBedrock() { _, ok := ResolveBedrockModelID(account, requestedModel) return ok } + // OpenAI 透传模式:仅替换认证,允许所有模型 + if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() { + return true + } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { requestedModel = claude.NormalizeModelID(requestedModel) @@ -3443,6 +3739,143 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo return account.IsModelSupported(requestedModel) } +func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool { + if account == nil { + return false + } + if strings.TrimSpace(requestedModel) == "" { + return true + } + + // 先走原始精确/通配符匹配。 + mapping := account.GetModelMapping() + if len(mapping) == 0 || account.IsModelSupported(requestedModel) { + return true + } + + aliases := buildSoraModelAliases(requestedModel) + if len(aliases) == 0 { + return false + } + + hasSoraSelector := false + for pattern := range mapping { + if !isSoraModelSelector(pattern) { + continue + } + hasSoraSelector = true + if matchPatternAnyAlias(pattern, aliases) { + return true + } + } + + // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*), + // 此时不应误拦截 Sora 模型请求。 + if !hasSoraSelector { + return true + } + + return false +} + +func matchPatternAnyAlias(pattern string, aliases []string) bool { + normalizedPattern := strings.ToLower(strings.TrimSpace(pattern)) + if normalizedPattern == "" { + return false + } + for _, alias := range aliases { + if matchWildcard(normalizedPattern, alias) { + return true + } + } + return false +} + +func isSoraModelSelector(pattern string) bool { + p := strings.ToLower(strings.TrimSpace(pattern)) + if p == "" { + return false + } + + switch { + case strings.HasPrefix(p, "sora"), + strings.HasPrefix(p, "gpt-image"), + strings.HasPrefix(p, "prompt-enhance"), + strings.HasPrefix(p, "sy_"): + return true + } + + return p == "video" || p == "image" +} + +func buildSoraModelAliases(requestedModel string) []string { + modelID := strings.ToLower(strings.TrimSpace(requestedModel)) + if modelID == "" { + return nil + } + + aliases := make([]string, 0, 8) + addAlias := func(value string) { + v := strings.ToLower(strings.TrimSpace(value)) + if v == "" { + return + } + for _, existing := range aliases { + if existing == v { + return + } + } + aliases = append(aliases, v) + } + + addAlias(modelID) + cfg, ok := GetSoraModelConfig(modelID) + if ok { + addAlias(cfg.Model) + switch cfg.Type { + case "video": + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case "image": + addAlias("image") + addAlias("gpt-image") + case "prompt_enhance": + addAlias("prompt-enhance") + } + return aliases + } + + switch { + case strings.HasPrefix(modelID, "sora"): + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case strings.HasPrefix(modelID, "gpt-image"): + addAlias("image") + addAlias("gpt-image") + case strings.HasPrefix(modelID, "prompt-enhance"): + addAlias("prompt-enhance") + default: + return nil + } + + return aliases +} + +func soraVideoFamilyAlias(modelID string) string { + switch { + case strings.HasPrefix(modelID, "sora2pro-hd"): + return "sora2pro-hd" + case strings.HasPrefix(modelID, "sora2pro"): + return "sora2pro" + case strings.HasPrefix(modelID, "sora2"): + return "sora2" + default: + return "" + } +} + // GetAccessToken 获取账号凭证 func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { @@ -3719,86 +4152,6 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { return result } -// rewriteSystemForNonClaudeCode 将非 Claude Code 客户端的 system prompt 迁移至 messages, -// system 字段仅保留 Claude Code 标识提示词。 -// Anthropic 基于 system 参数内容检测第三方应用,仅前置追加 Claude Code 提示词 -// 无法通过检测,因为后续内容仍为非 Claude Code 格式。 -// 策略:将原始 system prompt 提取并注入为 user/assistant 消息对,system 仅保留 Claude Code 标识。 -func rewriteSystemForNonClaudeCode(body []byte, system any) []byte { - system = normalizeSystemParam(system) - - // 1. 提取原始 system prompt 文本 - var originalSystemText string - switch v := system.(type) { - case string: - originalSystemText = strings.TrimSpace(v) - case []any: - var parts []string - for _, item := range v { - if m, ok := item.(map[string]any); ok { - if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" { - parts = append(parts, text) - } - } - } - originalSystemText = strings.Join(parts, "\n\n") - } - - // 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致) - // 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。 - // 使用 string 格式会被 Anthropic 检测为第三方应用。 - claudeCodeSystemBlock := []map[string]any{ - { - "type": "text", - "text": claudeCodeSystemPrompt, - "cache_control": map[string]string{"type": "ephemeral"}, - }, - } - out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock) - if !ok { - logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt") - return body - } - - // 3. 将原始 system prompt 作为 user/assistant 消息对注入到 messages 开头 - // 模型仍通过 messages 接收完整指令,保留客户端功能 - ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt) - if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) { - instrMsg, err1 := json.Marshal(map[string]any{ - "role": "user", - "content": []map[string]any{ - {"type": "text", "text": "[System Instructions]\n" + originalSystemText}, - }, - }) - ackMsg, err2 := json.Marshal(map[string]any{ - "role": "assistant", - "content": []map[string]any{ - {"type": "text", "text": "Understood. I will follow these instructions."}, - }, - }) - if err1 != nil || err2 != nil { - logger.LegacyPrintf("service.gateway", "Warning: failed to marshal system-to-messages injection") - return out - } - - // 重建 messages 数组:[instruction, ack, ...originalMessages] - items := [][]byte{instrMsg, ackMsg} - messagesResult := gjson.GetBytes(out, "messages") - if messagesResult.IsArray() { - messagesResult.ForEach(func(_, msg gjson.Result) bool { - items = append(items, []byte(msg.Raw)) - return true - }) - } - - if next, setOk := setJSONRawBytes(out, "messages", buildJSONArrayRaw(items)); setOk { - out = next - } - } - - return out -} - type cacheControlPath struct { path string log string @@ -3960,7 +4313,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest. // Always overwrite the cache to prevent stale values from a previous retry with a different account. if account.Platform == PlatformAnthropic && c != nil { - policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account, parsed.Model) + policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account) if policy.blockErr != nil { return nil, policy.blockErr } @@ -3990,24 +4343,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode if shouldMimicClaudeCode { - // 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages + // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 - systemRewritten := false if !strings.Contains(strings.ToLower(reqModel), "haiku") && !systemIncludesClaudeCodePrompt(parsed.System) { - body = rewriteSystemForNonClaudeCode(body, parsed.System) - systemRewritten = true + body = injectClaudeCodePrompt(body, parsed.System) } - // system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为); - // 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。 - // 两种情况下 enforceCacheControlLimit 都会兜底处理上限。 - normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten} + normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} if s.identityService != nil { fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) if err == nil && fp != nil { // metadata 透传开启时跳过 metadata 注入 - _, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx) + _, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx) if !mimicMPT { if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { normalizeOpts.injectMetadata = true @@ -4054,12 +4402,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return nil, err } - // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) + // 获取代理URL proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { - proxyURL = account.Proxy.URL() - } + proxyURL = account.Proxy.URL() } // 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析) @@ -4468,6 +4814,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 处理正常响应 + ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed) // 触发上游接受回调(提前释放串行锁,不等流完成) if parsed.OnUpstreamAccepted != nil { @@ -5534,16 +5881,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } targetURL = validatedURL + "/v1/messages?beta=true" } - } else if account.IsCustomBaseURLEnabled() { - customURL := account.GetCustomBaseURL() - if customURL == "" { - return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID) - } - validatedURL, err := s.validateUpstreamBaseURL(customURL) - if err != nil { - return nil, err - } - targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account) } clientHeaders := http.Header{} @@ -5553,9 +5890,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // OAuth账号:应用统一指纹和metadata重写(受设置开关控制) var fingerprint *Fingerprint - enableFP, enableMPT, enableCCH := true, false, false + enableFP, enableMPT := true, false if s.settingService != nil { - enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx) + enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx) } if account.IsOAuth() && s.identityService != nil { // 1. 获取或创建指纹(包含随机生成的ClientID) @@ -5582,15 +5919,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } - // 同步 billing header cc_version 与实际发送的 User-Agent 版本 - if fingerprint != nil { - body = syncBillingHeaderVersion(body, fingerprint.UserAgent) - } - // CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后) - if enableCCH { - body = signBillingHeaderCCH(body) - } - req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err @@ -5631,8 +5959,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } // Build effective drop set: merge static defaults with dynamic beta policy filter rules - policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID) + policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account) effectiveDropSet := mergeDropSets(policyFilterSet) + effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode) // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) if tokenType == "oauth" { @@ -5643,16 +5972,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex applyClaudeCodeMimicHeaders(req, reqStream) incomingBeta := getHeaderRaw(req.Header, "anthropic-beta") - // Claude Code OAuth credentials are scoped to Claude Code. - // Non-haiku models MUST include claude-code beta for Anthropic to recognize - // this as a legitimate Claude Code request; without it, the request is - // rejected as third-party ("out of extra usage"). - // Haiku models are exempt from third-party detection and don't need it. + // Match real Claude CLI traffic (per mitmproxy reports): + // messages requests typically use only oauth + interleaved-thinking. + // Also drop claude-code beta if a downstream client added it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - if !strings.Contains(strings.ToLower(modelID), "haiku") { - requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking} - } - setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet)) + setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") @@ -5672,15 +5996,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } - // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 - if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { - if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { - if parsed := ParseMetadataUserID(uid); parsed != nil { - setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) - } - } - } - // === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 === s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{ "url": req.URL.String(), @@ -5875,7 +6190,7 @@ type betaPolicyResult struct { } // evaluateBetaPolicy loads settings once and evaluates all rules against the given request. -func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account, model string) betaPolicyResult { +func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult { if s.settingService == nil { return betaPolicyResult{} } @@ -5890,11 +6205,10 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { continue } - effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model) - switch effectiveAction { + switch rule.Action { case BetaPolicyActionBlock: if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) { - msg := effectiveErrMsg + msg := rule.ErrorMessage if msg == "" { msg = "beta feature " + rule.BetaToken + " is not allowed" } @@ -5936,7 +6250,7 @@ const betaPolicyFilterSetKey = "betaPolicyFilterSet" // In the /v1/messages path, Forward() evaluates the policy first and caches the result; // buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this // evaluates on demand (one DB call). -func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account, model string) map[string]struct{} { +func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} { if c != nil { if v, ok := c.Get(betaPolicyFilterSetKey); ok { if fs, ok := v.(map[string]struct{}); ok { @@ -5944,7 +6258,7 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont } } } - return s.evaluateBetaPolicy(ctx, "", account, model).filterSet + return s.evaluateBetaPolicy(ctx, "", account).filterSet } // betaPolicyScopeMatches checks whether a rule's scope matches the current account type. @@ -5963,33 +6277,6 @@ func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool { } } -// matchModelWhitelist checks if a model matches any pattern in the whitelist. -// Reuses matchModelPattern from group.go which supports exact and wildcard prefix matching. -func matchModelWhitelist(model string, whitelist []string) bool { - for _, pattern := range whitelist { - if matchModelPattern(pattern, model) { - return true - } - } - return false -} - -// resolveRuleAction determines the effective action and error message for a rule given the request model. -// When ModelWhitelist is empty, the rule's primary Action/ErrorMessage applies unconditionally. -// When non-empty, Action applies to matching models; FallbackAction/FallbackErrorMessage applies to others. -func resolveRuleAction(rule BetaPolicyRule, model string) (action, errorMessage string) { - if len(rule.ModelWhitelist) == 0 { - return rule.Action, rule.ErrorMessage - } - if matchModelWhitelist(model, rule.ModelWhitelist) { - return rule.Action, rule.ErrorMessage - } - if rule.FallbackAction != "" { - return rule.FallbackAction, rule.FallbackErrorMessage - } - return BetaPolicyActionPass, "" // default fallback: pass (fail-open) -} - // droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. func droppedBetaSet(extra ...string) map[string]struct{} { m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) @@ -6036,7 +6323,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( modelID string, ) ([]string, error) { // 1. 对原始 header 中的 beta token 做 block 检查(快速失败) - policy := s.evaluateBetaPolicy(ctx, betaHeader, account, modelID) + policy := s.evaluateBetaPolicy(ctx, betaHeader, account) if policy.blockErr != nil { return nil, policy.blockErr } @@ -6048,7 +6335,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( // 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token, // 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 → // 如果不做此检查,block 规则会被绕过。 - if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account, modelID); blockErr != nil { + if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil { return nil, blockErr } @@ -6057,7 +6344,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( // checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。 // 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。 -func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account, model string) *BetaBlockedError { +func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError { if s.settingService == nil || len(tokens) == 0 { return nil } @@ -6069,15 +6356,14 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke isBedrock := account.IsBedrock() tokenSet := buildBetaTokenSet(tokens) for _, rule := range settings.Rules { - effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model) - if effectiveAction != BetaPolicyActionBlock { + if rule.Action != BetaPolicyActionBlock { continue } if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { continue } if _, present := tokenSet[rule.BetaToken]; present { - msg := effectiveErrMsg + msg := rule.ErrorMessage if msg == "" { msg = "beta feature " + rule.BetaToken + " is not allowed" } @@ -6709,6 +6995,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http needModelReplace := originalModel != mappedModel clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage sawTerminalEvent := false + skipAccountTTLOverride := false pendingEventLines := make([]string, 0, 4) @@ -6770,17 +7057,25 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { eventChanged = reconcileCachedTokens(u) || eventChanged + claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID) + if claudeMaxOutcome.Simulated { + skipAccountTTLOverride = true + } } } } if eventType == "message_delta" { if u, ok := event["usage"].(map[string]any); ok { eventChanged = reconcileCachedTokens(u) || eventChanged + claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID) + if claudeMaxOutcome.Simulated { + skipAccountTTLOverride = true + } } } // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() { + if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride { overrideTarget := account.GetCacheTTLOverrideTarget() if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { @@ -7212,8 +7507,13 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } + claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID) + if claudeMaxOutcome.Simulated { + body = rewriteClaudeUsageJSONBytes(body, response.Usage) + } + // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() { + if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated { overrideTarget := account.GetCacheTTLOverrideTarget() if applyCacheTTLOverride(&response.Usage, overrideTarget) { // 同步更新 body JSON 中的嵌套 cache_creation 对象 @@ -7279,6 +7579,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { Result *ForwardResult + ParsedRequest *ParsedRequest APIKey *APIKey User *User Account *Account @@ -7437,6 +7738,9 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage cmd.CacheCreationTokens = usageLog.CacheCreationTokens cmd.CacheReadTokens = usageLog.CacheReadTokens cmd.ImageCount = usageLog.ImageCount + if usageLog.MediaType != nil { + cmd.MediaType = *usageLog.MediaType + } if usageLog.ServiceTier != nil { cmd.ServiceTier = *usageLog.ServiceTier } @@ -7592,6 +7896,8 @@ type recordUsageOpts struct { // EnableClaudePath 启用 Claude 路径特有逻辑: // - Claude Max 缓存计费策略 + // - Sora 媒体类型分支(image/video/prompt) + // - MediaType 字段写入使用日志 EnableClaudePath bool // 长上下文计费(仅 Gemini 路径需要) @@ -7616,6 +7922,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu APIKeyService: input.APIKeyService, ChannelUsageFields: input.ChannelUsageFields, }, &recordUsageOpts{ + ParsedRequest: input.ParsedRequest, EnableClaudePath: true, }) } @@ -7682,6 +7989,7 @@ type recordUsageCoreInput struct { // recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 // opts 中的字段控制两者之间的差异行为: // - ParsedRequest != nil → 启用 Claude Max 缓存计费策略 +// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt) // - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { result := input.Result @@ -7699,9 +8007,21 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage result.Usage.InputTokens = 0 } - // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + // Claude Max cache billing policy(仅 Claude 路径启用) cacheTTLOverridden := false - if account.IsCacheTTLOverrideEnabled() { + simulatedClaudeMax := false + if opts.EnableClaudePath { + var apiKeyGroup *Group + if apiKey != nil { + apiKeyGroup = apiKey.Group + } + claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, opts.ParsedRequest, apiKeyGroup, result.Model, account.ID) + simulatedClaudeMax = claudeMaxOutcome.Simulated || + (shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, opts.ParsedRequest) && hasCacheCreationTokens(result.Usage)) + } + + // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax { applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 } @@ -7783,6 +8103,16 @@ func (s *GatewayService) calculateRecordUsageCost( multiplier float64, opts *recordUsageOpts, ) *CostBreakdown { + // Sora 媒体类型分支(仅 Claude 路径启用) + if opts.EnableClaudePath { + if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo { + return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier) + } + if result.MediaType == MediaTypePrompt { + return &CostBreakdown{} + } + } + // 图片生成计费 if result.ImageCount > 0 { return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier) @@ -7792,6 +8122,28 @@ func (s *GatewayService) calculateRecordUsageCost( return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts) } +// calculateSoraMediaCost 计算 Sora 图片/视频的费用。 +func (s *GatewayService) calculateSoraMediaCost( + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, +) *CostBreakdown { + var soraConfig *SoraPriceConfig + if apiKey.Group != nil { + soraConfig = &SoraPriceConfig{ + ImagePrice360: apiKey.Group.SoraImagePrice360, + ImagePrice540: apiKey.Group.SoraImagePrice540, + VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, + } + } + if result.MediaType == MediaTypeImage { + return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) + } + return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) +} + // resolveChannelPricing 检查指定模型是否存在渠道级别定价。 // 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。 func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { @@ -7814,7 +8166,7 @@ func (s *GatewayService) calculateImageCost( billingModel string, multiplier float64, ) *CostBreakdown { - if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { + if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil { tokens := UsageTokens{ InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -7829,7 +8181,6 @@ func (s *GatewayService) calculateImageCost( RequestCount: 1, RateMultiplier: multiplier, Resolver: s.resolver, - Resolved: resolved, }) if err != nil { logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err) @@ -7872,7 +8223,7 @@ func (s *GatewayService) calculateTokenCost( var err error // 优先尝试渠道定价 → CalculateCostUnified - if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { + if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil { gid := apiKey.Group.ID cost, err = s.billingService.CalculateCostUnified(CostInput{ Ctx: ctx, @@ -7882,7 +8233,6 @@ func (s *GatewayService) calculateTokenCost( RequestCount: 1, RateMultiplier: multiplier, Resolver: s.resolver, - Resolved: resolved, }) } else if opts.LongContextThreshold > 0 { // 长上下文双倍计费(如 Gemini 200K 阈值) @@ -7940,12 +8290,13 @@ func (s *GatewayService) buildRecordUsageLog( RateMultiplier: multiplier, AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, - BillingMode: resolveBillingMode(result, cost), + BillingMode: resolveBillingMode(opts, result, cost), Stream: result.Stream, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: optionalTrimmedStringPtr(result.ImageSize), + MediaType: resolveMediaType(opts, result), CacheTTLOverridden: cacheTTLOverridden, ChannelID: optionalInt64Ptr(input.ChannelID), ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), @@ -7969,7 +8320,13 @@ func (s *GatewayService) buildRecordUsageLog( } // resolveBillingMode 根据计费结果和请求类型确定计费模式。 -func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string { +// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。 +func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string { + isSoraMedia := opts.EnableClaudePath && + (result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt) + if isSoraMedia { + return nil + } var mode string switch { case cost != nil && cost.BillingMode != "": @@ -7982,6 +8339,13 @@ func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string { return &mode } +func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string { + if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" { + return &result.MediaType + } + return nil +} + func optionalSubscriptionID(subscription *UserSubscription) *int64 { if subscription != nil { return &subscription.ID @@ -8010,8 +8374,8 @@ func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, m return s.channelService.IsModelRestricted(ctx, groupID, model) } -// ResolveChannelMappingAndRestrict 解析渠道映射。 -// 模型限制检查已移至调度阶段(checkChannelPricingRestriction),restricted 始终返回 false。 +// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。 +// 返回映射结果和是否被限制。 func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { if s.channelService == nil { return ChannelMappingResult{MappedModel: model}, false @@ -8042,9 +8406,7 @@ func billingModelForRestriction(source, requestedModel, channelMappedModel strin return requestedModel case BillingModelSourceUpstream: return "" - case BillingModelSourceChannelMapped: - return channelMappedModel - default: + default: // channel_mapped return channelMappedModel } } @@ -8076,11 +8438,7 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex return false } ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) - if err != nil { - slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err) - return false - } - if ch == nil || !ch.RestrictModels { + if err != nil || ch == nil || !ch.RestrictModels { return false } return ch.BillingModelSource == BillingModelSourceUpstream @@ -8172,12 +8530,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return err } - // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) + // 获取代理URL proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { - proxyURL = account.Proxy.URL() - } + proxyURL = account.Proxy.URL() } // 发送请求 @@ -8456,16 +8812,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" } - } else if account.IsCustomBaseURLEnabled() { - customURL := account.GetCustomBaseURL() - if customURL == "" { - return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID) - } - validatedURL, err := s.validateUpstreamBaseURL(customURL) - if err != nil { - return nil, err - } - targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account) } clientHeaders := http.Header{} @@ -8475,9 +8821,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:应用统一指纹和重写 userID(受设置开关控制) // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 - ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false + ctEnableFP, ctEnableMPT := true, false if s.settingService != nil { - ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx) + ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx) } var ctFingerprint *Fingerprint if account.IsOAuth() && s.identityService != nil { @@ -8495,14 +8841,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } - // 同步 billing header cc_version 与实际发送的 User-Agent 版本 - if ctFingerprint != nil && ctEnableFP { - body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent) - } - if ctEnableCCH { - body = signBillingHeaderCCH(body) - } - req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err @@ -8543,7 +8881,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } // Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules - ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID)) + ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account)) // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { @@ -8579,15 +8917,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } - // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 - if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { - if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { - if parsed := ParseMetadataUserID(uid); parsed != nil { - setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) - } - } - } - if c != nil && tokenType == "oauth" { c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) } @@ -8609,19 +8938,6 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m }) } -// buildCustomRelayURL 构建自定义中继转发 URL -// 在 path 后附加 beta=true 和可选的 proxy 查询参数 -func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string { - u := strings.TrimRight(baseURL, "/") + path + "?beta=true" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL := account.Proxy.URL() - if proxyURL != "" { - u += "&proxy=" + url.QueryEscape(proxyURL) - } - } - return u -} - func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) From 160903fce7841b99d761b30372f793bcbf448c5f Mon Sep 17 00:00:00 2001 From: erio Date: Thu, 2 Apr 2026 13:36:58 +0800 Subject: [PATCH 04/88] fix: address review findings for channel restriction refactoring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix 7 stale comments still mentioning "限制检查" in handlers/services - Make billingModelForRestriction explicitly list channel_mapped case - Add slog.Warn for error swallowing in ResolveChannelMapping and needsUpstreamChannelRestrictionCheck - Document sticky session upstream check exemption --- .../handler/gateway_handler_chat_completions.go | 2 +- .../handler/gateway_handler_responses.go | 2 +- .../internal/handler/gemini_v1beta_handler.go | 2 +- .../internal/handler/openai_chat_completions.go | 2 +- .../internal/handler/openai_gateway_handler.go | 2 +- backend/internal/service/gateway_service.go | 17 +++++++++++++---- 6 files changed, 18 insertions(+), 9 deletions(-) diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index abe2a1e5..be267332 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -80,7 +80,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) - // 解析渠道级模型映射 + 限制检查 + // 解析渠道级模型映射 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) // Claude Code only restriction diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index cf877182..e908eb9e 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -80,7 +80,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) - // 解析渠道级模型映射 + 限制检查 + // 解析渠道级模型映射 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) // Claude Code only restriction: diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index ff63bc7f..45b5842f 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -184,7 +184,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { setOpsRequestContext(c, modelName, stream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) - // 解析渠道级模型映射 + 限制检查 + // 解析渠道级模型映射 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName) reqModel := modelName // 保存映射前的原始模型名 if channelMapping.Mapped { diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index ada401c9..991cbb91 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -79,7 +79,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) - // 解析渠道级模型映射 + 限制检查 + // 解析渠道级模型映射 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) if h.errorPassthroughService != nil { diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 2b081617..dda6d2e3 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -1118,7 +1118,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { setOpsRequestContext(c, reqModel, true, firstMessage) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) - // 解析渠道级模型映射 + 限制检查 + // 解析渠道级模型映射 channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel) var currentUserRelease func() diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 33ab38f2..24f36113 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3143,6 +3143,8 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持) + // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查, + // 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。 needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { @@ -3381,6 +3383,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) + // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。 needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { @@ -8374,8 +8377,8 @@ func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, m return s.channelService.IsModelRestricted(ctx, groupID, model) } -// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。 -// 返回映射结果和是否被限制。 +// ResolveChannelMappingAndRestrict 解析渠道映射。 +// 模型限制检查已移至调度阶段(checkChannelPricingRestriction),restricted 始终返回 false。 func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { if s.channelService == nil { return ChannelMappingResult{MappedModel: model}, false @@ -8406,7 +8409,9 @@ func billingModelForRestriction(source, requestedModel, channelMappedModel strin return requestedModel case BillingModelSourceUpstream: return "" - default: // channel_mapped + case BillingModelSourceChannelMapped: + return channelMappedModel + default: return channelMappedModel } } @@ -8438,7 +8443,11 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex return false } ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) - if err != nil || ch == nil || !ch.RestrictModels { + if err != nil { + slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err) + return false + } + if ch == nil || !ch.RestrictModels { return false } return ch.BillingModelSource == BillingModelSourceUpstream From e3748741257c2c6b96ced2c0c0284dd31cc5e358 Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 3 Apr 2026 13:54:18 +0800 Subject: [PATCH 05/88] feat(channel): improve cache strategy and add restriction logging - Change channel cache TTL from 60s to 10min (reduce unnecessary DB queries) - Actively rebuild cache after CRUD instead of lazy invalidation - Add slog.Warn logging for channel pricing restriction blocks (4 places) --- backend/internal/service/channel_service.go | 378 ++++++++------------ backend/internal/service/gateway_service.go | 46 ++- 2 files changed, 183 insertions(+), 241 deletions(-) diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index 9667cb98..c6a249ef 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -134,7 +134,7 @@ func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) Chan const ( channelCacheTTL = 10 * time.Minute - channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 + channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 channelCacheDBTimeout = 10 * time.Second ) @@ -197,8 +197,10 @@ func newEmptyChannelCache() *channelCache { } // expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。 -// 各平台严格独立:antigravity 分组只匹配 antigravity 定价,不会匹配 anthropic/gemini 的定价。 -// 查找时通过 lookupPricingAcrossPlatforms() 在本平台内查找。 +// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。 +// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台, +// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。 +// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。 func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { for j := range ch.ModelPricing { pricing := &ch.ModelPricing[j] @@ -224,7 +226,8 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform } // expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。 -// 各平台严格独立:antigravity 分组只匹配 antigravity 映射。 +// antigravity 平台同时服务 Claude 和 Gemini 模型。 +// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。 func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { for _, mappingPlatform := range matchingPlatforms(platform) { platformMapping, ok := ch.ModelMapping[mappingPlatform] @@ -248,58 +251,40 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform } } -// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。 -// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。 -func (s *ChannelService) storeErrorCache() { - errorCache := newEmptyChannelCache() - errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) - s.cache.Store(errorCache) -} - // buildCache 从数据库构建渠道缓存。 // 使用独立 context 避免请求取消导致空值被长期缓存。 func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) { + // 断开请求取消链,避免客户端断连导致空值被长期缓存 dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout) defer cancel() - channels, groupPlatforms, err := s.fetchChannelData(dbCtx) - if err != nil { - return nil, err - } - - cache := populateChannelCache(channels, groupPlatforms) - s.cache.Store(cache) - return cache, nil -} - -// fetchChannelData 从数据库加载渠道列表和分组平台映射。 -func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) { - channels, err := s.repo.ListAll(ctx) + channels, err := s.repo.ListAll(dbCtx) if err != nil { + // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 slog.Warn("failed to build channel cache", "error", err) - s.storeErrorCache() - return nil, nil, fmt.Errorf("list all channels: %w", err) + errorCache := newEmptyChannelCache() + errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL + s.cache.Store(errorCache) + return nil, fmt.Errorf("list all channels: %w", err) } + // 收集所有 groupID,批量查询 platform var allGroupIDs []int64 for i := range channels { allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...) } - groupPlatforms := make(map[int64]string) if len(allGroupIDs) > 0 { - groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs) + groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs) if err != nil { slog.Warn("failed to load group platforms for channel cache", "error", err) - s.storeErrorCache() - return nil, nil, fmt.Errorf("get group platforms: %w", err) + errorCache := newEmptyChannelCache() + errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) + s.cache.Store(errorCache) + return nil, fmt.Errorf("get group platforms: %w", err) } } - return channels, groupPlatforms, nil -} -// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。 -func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache { cache := newEmptyChannelCache() cache.groupPlatform = groupPlatforms cache.byID = make(map[int64]*Channel, len(channels)) @@ -308,6 +293,7 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) * for i := range channels { ch := &channels[i] cache.byID[ch.ID] = ch + for _, gid := range ch.GroupIDs { cache.channelByGroupID[gid] = ch platform := groupPlatforms[gid] @@ -315,20 +301,33 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) * expandMappingToCache(cache, ch, gid, platform) } } - return cache + + // 通配符条目保持配置顺序(最先匹配到优先) + + s.cache.Store(cache) + return cache, nil } // invalidateCache 使缓存失效,让下次读取时自然重建 // isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。 -// 各平台(antigravity / anthropic / gemini / openai)严格独立,不跨平台匹配。 +// antigravity 平台同时服务 Claude(anthropic)和 Gemini(gemini)模型, +// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。 func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool { - return groupPlatform == pricingPlatform + if groupPlatform == pricingPlatform { + return true + } + if groupPlatform == PlatformAntigravity { + return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini + } + return false } -// matchingPlatforms 返回分组平台对应的可匹配平台列表。 -// 各平台严格独立,只返回自身。 +// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。 func matchingPlatforms(groupPlatform string) []string { + if groupPlatform == PlatformAntigravity { + return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini} + } return []string{groupPlatform} } func (s *ChannelService) invalidateCache() { @@ -365,8 +364,10 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower return "" } -// lookupPricingAcrossPlatforms 在分组平台内查找模型定价。 -// 各平台严格独立,只在本平台内查找(先精确匹配,再通配符)。 +// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。 +// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试 +// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini), +// 返回第一个命中的结果。非 antigravity 平台只尝试自身。 func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing { for _, p := range matchingPlatforms(groupPlatform) { key := channelModelKey{groupID: groupID, platform: p, model: modelLower} @@ -383,7 +384,7 @@ func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatf return nil } -// lookupMappingAcrossPlatforms 在分组平台内查找模型映射。 +// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。 // 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。 func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string { for _, p := range matchingPlatforms(groupPlatform) { @@ -441,7 +442,8 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64) } // GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。 -// 各平台严格独立,只在本平台内查找定价。 +// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini), +// 确保跨平台同名模型各自独立匹配。 func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing { lk, err := s.lookupGroupChannel(ctx, groupID) if err != nil { @@ -479,10 +481,7 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6 // 返回 true 表示模型被限制(不在允许列表中)。 // 如果渠道未启用模型限制或分组无渠道关联,返回 false。 func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { - lk, err := s.lookupGroupChannel(ctx, groupID) - if err != nil { - slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err) - } + lk, _ := s.lookupGroupChannel(ctx, groupID) if lk == nil { return false } @@ -525,7 +524,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi } // checkRestricted 基于已查找的渠道信息检查模型是否被限制。 -// 只在本平台的定价列表中查找。 +// antigravity 分组依次尝试所有匹配平台的定价列表。 func checkRestricted(lk *channelLookup, groupID int64, model string) bool { if !lk.channel.RestrictModels { return false @@ -553,91 +552,6 @@ func ReplaceModelInBody(body []byte, newModel string) []byte { return newBody } -// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。 -// Create 和 Update 共用此函数,避免重复。 -func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error { - if err := validateNoConflictingModels(pricing); err != nil { - return err - } - if err := validatePricingIntervals(pricing); err != nil { - return err - } - if err := validateNoConflictingMappings(mapping); err != nil { - return err - } - return validatePricingBillingMode(pricing) -} - -// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。 -func validatePricingBillingMode(pricing []ChannelModelPricing) error { - for _, p := range pricing { - if err := checkBillingModeRequirements(p); err != nil { - return err - } - if err := checkPricesNotNegative(p); err != nil { - return err - } - if err := checkIntervalsHavePrices(p); err != nil { - return err - } - } - return nil -} - -func checkBillingModeRequirements(p ChannelModelPricing) error { - if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage { - if p.PerRequestPrice == nil && len(p.Intervals) == 0 { - return infraerrors.BadRequest( - "BILLING_MODE_MISSING_PRICE", - "per-request price or intervals required for per_request/image billing mode", - ) - } - } - return nil -} - -func checkPricesNotNegative(p ChannelModelPricing) error { - checks := []struct { - field string - val *float64 - }{ - {"input_price", p.InputPrice}, - {"output_price", p.OutputPrice}, - {"cache_write_price", p.CacheWritePrice}, - {"cache_read_price", p.CacheReadPrice}, - {"image_output_price", p.ImageOutputPrice}, - {"per_request_price", p.PerRequestPrice}, - } - for _, c := range checks { - if c.val != nil && *c.val < 0 { - return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field)) - } - } - return nil -} - -func checkIntervalsHavePrices(p ChannelModelPricing) error { - for _, iv := range p.Intervals { - if iv.InputPrice == nil && iv.OutputPrice == nil && - iv.CacheWritePrice == nil && iv.CacheReadPrice == nil && - iv.PerRequestPrice == nil { - return infraerrors.BadRequest( - "INTERVAL_MISSING_PRICE", - fmt.Sprintf("interval [%d, %s] has no price fields set for model %v", - iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models), - ) - } - } - return nil -} - -func formatMaxTokens(max *int) string { - if max == nil { - return "∞" - } - return fmt.Sprintf("%d", *max) -} - // --- CRUD --- // Create 创建渠道 @@ -650,8 +564,15 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) return nil, ErrChannelExists } - if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil { - return nil, err + // 检查分组冲突 + if len(input.GroupIDs) > 0 { + conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs) + if err != nil { + return nil, fmt.Errorf("check group conflicts: %w", err) + } + if len(conflicting) > 0 { + return nil, ErrGroupAlreadyInChannel + } } channel := &Channel{ @@ -668,7 +589,13 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) channel.BillingModelSource = BillingModelSourceChannelMapped } - if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { + if err := validateNoConflictingModels(channel.ModelPricing); err != nil { + return nil, err + } + if err := validatePricingIntervals(channel.ModelPricing); err != nil { + return nil, err + } + if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { return nil, err } @@ -692,112 +619,102 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan return nil, fmt.Errorf("get channel: %w", err) } - if err := s.applyUpdateInput(ctx, channel, input); err != nil { + if input.Name != "" && input.Name != channel.Name { + exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id) + if err != nil { + return nil, fmt.Errorf("check channel exists: %w", err) + } + if exists { + return nil, ErrChannelExists + } + channel.Name = input.Name + } + + if input.Description != nil { + channel.Description = *input.Description + } + + if input.Status != "" { + channel.Status = input.Status + } + + if input.RestrictModels != nil { + channel.RestrictModels = *input.RestrictModels + } + + // 检查分组冲突 + if input.GroupIDs != nil { + conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs) + if err != nil { + return nil, fmt.Errorf("check group conflicts: %w", err) + } + if len(conflicting) > 0 { + return nil, ErrGroupAlreadyInChannel + } + channel.GroupIDs = *input.GroupIDs + } + + if input.ModelPricing != nil { + channel.ModelPricing = *input.ModelPricing + } + + if input.ModelMapping != nil { + channel.ModelMapping = input.ModelMapping + } + + if input.BillingModelSource != "" { + channel.BillingModelSource = input.BillingModelSource + } + + if err := validateNoConflictingModels(channel.ModelPricing); err != nil { + return nil, err + } + if err := validatePricingIntervals(channel.ModelPricing); err != nil { + return nil, err + } + if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { return nil, err } - if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { - return nil, err + // 先获取旧分组,Update 后旧分组关联已删除,无法再查到 + var oldGroupIDs []int64 + if s.authCacheInvalidator != nil { + var err2 error + oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id) + if err2 != nil { + slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2) + } } - oldGroupIDs := s.getOldGroupIDs(ctx, id) - if err := s.repo.Update(ctx, channel); err != nil { return nil, fmt.Errorf("update channel: %w", err) } s.invalidateCache() - s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs) + + // 失效新旧分组的 auth 缓存 + if s.authCacheInvalidator != nil { + seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs)) + for _, gid := range oldGroupIDs { + if _, ok := seen[gid]; !ok { + seen[gid] = struct{}{} + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } + for _, gid := range channel.GroupIDs { + if _, ok := seen[gid]; !ok { + seen[gid] = struct{}{} + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } + } return s.repo.GetByID(ctx, id) } -// applyUpdateInput 将更新请求的字段应用到渠道实体上。 -func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error { - if input.Name != "" && input.Name != channel.Name { - exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID) - if err != nil { - return fmt.Errorf("check channel exists: %w", err) - } - if exists { - return ErrChannelExists - } - channel.Name = input.Name - } - if input.Description != nil { - channel.Description = *input.Description - } - if input.Status != "" { - channel.Status = input.Status - } - if input.RestrictModels != nil { - channel.RestrictModels = *input.RestrictModels - } - if input.GroupIDs != nil { - if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil { - return err - } - channel.GroupIDs = *input.GroupIDs - } - if input.ModelPricing != nil { - channel.ModelPricing = *input.ModelPricing - } - if input.ModelMapping != nil { - channel.ModelMapping = input.ModelMapping - } - if input.BillingModelSource != "" { - channel.BillingModelSource = input.BillingModelSource - } - return nil -} - -// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。 -// channelID 为当前渠道 ID(Create 时传 0)。 -func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error { - if len(groupIDs) == 0 { - return nil - } - conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs) - if err != nil { - return fmt.Errorf("check group conflicts: %w", err) - } - if len(conflicting) > 0 { - return ErrGroupAlreadyInChannel - } - return nil -} - -// getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。 -func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 { - if s.authCacheInvalidator == nil { - return nil - } - oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID) - if err != nil { - slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err) - } - return oldGroupIDs -} - -// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。 -func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) { - if s.authCacheInvalidator == nil { - return - } - seen := make(map[int64]struct{}) - for _, ids := range groupIDSets { - for _, gid := range ids { - if _, ok := seen[gid]; ok { - continue - } - seen[gid] = struct{}{} - s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) - } - } -} - // Delete 删除渠道 func (s *ChannelService) Delete(ctx context.Context, id int64) error { + // 先获取关联分组用于失效缓存 groupIDs, err := s.repo.GetGroupIDs(ctx, id) if err != nil { slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err) @@ -808,7 +725,12 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error { } s.invalidateCache() - s.invalidateAuthCacheForGroups(ctx, groupIDs) + + if s.authCacheInvalidator != nil { + for _, gid := range groupIDs { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } return nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 24f36113..31137fb4 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1234,11 +1234,6 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - // 渠道定价限制预检查(requested / channel_mapped 基准) - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) - } - // 优先检查 context 中的强制平台(/antigravity 路由) var platform string forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) @@ -1257,6 +1252,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context platform = PlatformAnthropic } + // Claude Code 限制可能已将 groupID 解析为 fallback group, + // 渠道限制预检查必须使用解析后的分组。 + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", requestedModel) + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 注意:强制平台模式不走混合调度 if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { @@ -1273,11 +1277,6 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context // metadataUserID: 用于客户端亲和调度,从中提取客户端 ID // sub2apiUserID: 系统用户 ID,用于二维亲和调度 func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) { - // 渠道定价限制预检查(requested / channel_mapped 基准) - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) - } - // 调试日志:记录调度入口参数 excludedIDsList := make([]int64, 0, len(excludedIDs)) for id := range excludedIDs { @@ -1298,6 +1297,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } ctx = s.withGroupContext(ctx, group) + // Claude Code 限制可能已将 groupID 解析为 fallback group, + // 渠道限制预检查必须使用解析后的分组。 + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", requestedModel) + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + var stickyAccountID int64 if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { stickyAccountID = prefetch @@ -3004,7 +3012,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ -3359,7 +3367,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { return account, nil } @@ -3383,7 +3391,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) - // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。 needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { @@ -8453,6 +8460,19 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex return ch.BillingModelSource == BillingModelSourceUpstream } +// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。 +// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用, +// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。 +func (s *GatewayService) isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool { + if groupID == nil { + return false + } + if !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) { + return false + } + return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) +} + // ForwardCountTokens 转发 count_tokens 请求到上游 API // 特点:不记录使用量、仅支持非流式响应 func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { From 37c23eccfed36d59e70369e37d26b1eff5ac4a0f Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 5 Apr 2026 14:37:21 +0800 Subject: [PATCH 06/88] fix: gofmt formatting --- backend/internal/service/channel_service.go | 2 +- backend/internal/service/gateway_service.go | 689 +++----------------- 2 files changed, 82 insertions(+), 609 deletions(-) diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index c6a249ef..ec8310f6 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -134,7 +134,7 @@ func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) Chan const ( channelCacheTTL = 10 * time.Minute - channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 + channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 channelCacheDBTimeout = 10 * time.Second ) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 31137fb4..5d285fb6 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -12,6 +12,7 @@ import ( "log/slog" mathrand "math/rand" "net/http" + "net/url" "os" "path/filepath" "regexp" @@ -41,8 +42,7 @@ import ( const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" - stickySessionTTL = time.Hour // 粘性会话TTL - ClientAffinityTTL = 24 * time.Hour // 客户端亲和TTL + stickySessionTTL = time.Hour // 粘性会话TTL defaultMaxLineSize = 500 * 1024 * 1024 // Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines) // to match real Claude CLI traffic as closely as possible. When we need a visual @@ -60,28 +60,14 @@ const ( claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) -// MediaType 媒体类型常量 -const ( - MediaTypeImage = "image" - MediaTypeVideo = "video" - MediaTypePrompt = "prompt" -) - -const ( - claudeMaxMessageOverheadTokens = 3 - claudeMaxBlockOverheadTokens = 1 - claudeMaxUnknownContentTokens = 4 -) - // ForceCacheBillingContextKey 强制缓存计费上下文键 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 type forceCacheBillingKeyType struct{} // accountWithLoad 账号与负载信息的组合,用于负载感知调度 type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo - affinityCount int64 // 亲和客户端数量(反向索引),越少越优先 + account *Account + loadInfo *AccountLoadInfo } var ForceCacheBillingContextKey = forceCacheBillingKeyType{} @@ -345,10 +331,6 @@ var ( sseDataRe = regexp.MustCompile(`^data:\s*`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) - // clientIDFromMetadataRegex 从 metadata.user_id 中提取客户端 ID(64位 hex) - // 格式: user_{64位hex}_account_... - clientIDFromMetadataRegex = regexp.MustCompile(`^user_([a-f0-9]{64})_account_`) - // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 // 注意:前缀之间不应存在包含关系,否则会导致冗余匹配 @@ -366,12 +348,6 @@ var ErrNoAvailableAccounts = errors.New("no available accounts") // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") -// ErrAffinityNoSwitch 表示亲和账号不可用且不允许切换到其他账号 -var ErrAffinityNoSwitch = errors.New("affinity account unavailable and switching is disabled") - -// ErrAffinityLimitExceeded 表示亲和客户端限制已达上限 -var ErrAffinityLimitExceeded = errors.New("affinity client limit exceeded") - // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, @@ -393,6 +369,8 @@ var allowedHeaders = map[string]bool{ "user-agent": true, "content-type": true, "accept-encoding": true, + "x-claude-code-session-id": true, + "x-client-request-id": true, } // GatewayCache 定义网关服务的缓存操作接口。 @@ -413,39 +391,6 @@ type GatewayCache interface { // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 // Delete sticky session binding, used to proactively clean up when account becomes unavailable DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error - - // GetAffinityAccounts 获取亲和账号列表(按最近使用降序),同时清理过期成员 - GetAffinityAccounts(ctx context.Context, groupID int64, userID int64, clientID string, ttl time.Duration) ([]int64, error) - // UpdateAffinity 添加/更新亲和关系(更新 score 为当前时间戳,刷新 key TTL) - UpdateAffinity(ctx context.Context, groupID int64, userID int64, clientID string, accountID int64, ttl time.Duration) error - // GetAccountAffinityCountBatch 批量获取账号的亲和成员数量(惰性清理过期成员) - GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error) - // GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和成员列表(去重) - // accountGroups: map[accountID][]groupID - // 返回值成员格式为 {userID}/{clientID} - GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error) - // GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间) - GetAccountAffinityClientsWithScores(ctx context.Context, accountID int64, groupIDs []int64, ttl time.Duration) ([]AffinityClient, error) - // ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引) - // 用于账号关闭亲和时立即清理旧绑定 - ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error - // GetAffinityMultiCount 获取账号的多维度亲和计数 - // 返回: uniqueUsers, uniqueClients, perUserClients - GetAffinityMultiCount(ctx context.Context, groupID int64, accountID int64, targetUserID int64, ttl time.Duration) (users, clients, perUser int64, err error) -} - -// AffinityClient 亲和客户端信息(含用户 ID 和最后活跃时间) -type AffinityClient struct { - UserID int64 `json:"user_id"` - ClientID string `json:"client_id"` - LastActive time.Time `json:"last_active"` -} - -// SortAffinityClients 按最后活跃时间降序排序 -func SortAffinityClients(clients []AffinityClient) { - sort.Slice(clients, func(i, j int) bool { - return clients[i].LastActive.After(clients[j].LastActive) - }) } // derefGroupID safely dereferences *int64 to int64, returning 0 if nil @@ -516,20 +461,6 @@ func shouldClearStickySession(account *Account, requestedModel string) bool { return false } -// extractClientIDFromMetadata 从 metadata.user_id 中提取客户端 ID(64位 hex)。 -// 格式: user_{64位hex}_account_..._session_... -// 返回空字符串表示无法提取(非 Claude Code/Console 客户端)。 -func extractClientIDFromMetadata(metadataUserID string) string { - if metadataUserID == "" { - return "" - } - matches := clientIDFromMetadataRegex.FindStringSubmatch(metadataUserID) - if matches == nil { - return "" - } - return matches[1] -} - type AccountWaitPlan struct { AccountID int64 MaxConcurrency int @@ -572,10 +503,6 @@ type ForwardResult struct { // 图片生成计费字段(图片生成模型使用) ImageCount int // 生成的图片数量 ImageSize string // 图片尺寸 "1K", "2K", "4K" - - // Sora 媒体字段 - MediaType string // image / video / prompt - MediaURL string // 生成后的媒体地址(可选) } // UpstreamFailoverError indicates an upstream error that should trigger account failover. @@ -1315,10 +1242,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // 提取客户端 ID(用于客户端亲和调度) - affinityClientID := extractClientIDFromMetadata(metadataUserID) - affinityUserID := sub2apiUserID - if s.debugModelRoutingEnabled() && requestedModel != "" { groupPlatform := "" if group != nil { @@ -1340,10 +1263,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if err != nil { return nil, err } - if shouldFilterAccountWithoutClientID(account, affinityClientID) { - localExcluded[account.ID] = struct{}{} - continue - } result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) if err == nil && result.Acquired { @@ -1405,7 +1324,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if err != nil { return nil, err } - accounts = filterAccountsWithoutClientID(accounts, affinityClientID) if len(accounts) == 0 { return nil, ErrNoAvailableAccounts } @@ -1424,19 +1342,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro _, excluded := excludedIDs[accountID] return excluded } - affinityFlow := newGatewayAffinityFlow( - s, - ctx, - groupID, - sessionHash, - requestedModel, - affinityClientID, - affinityUserID, - platform, - useMixed, - accountByID, - isExcluded, - ) // 获取模型路由配置(仅 anthropic 平台) var routingAccountIDs []int64 @@ -1599,10 +1504,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } if len(routingAvailable) > 0 { - // 批量获取亲和客户端数量 - s.populateAffinityCounts(ctx, routingAvailable, derefGroupID(groupID)) - - // 排序:优先级 > 负载率 > 亲和客户端数 > 最后使用时间 + // 排序:优先级 > 负载率 > 最后使用时间 sort.SliceStable(routingAvailable, func(i, j int) bool { a, b := routingAvailable[i], routingAvailable[j] if a.account.Priority != b.account.Priority { @@ -1611,9 +1513,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if a.loadInfo.LoadRate != b.loadInfo.LoadRate { return a.loadInfo.LoadRate < b.loadInfo.LoadRate } - if a.affinityCount != b.affinityCount { - return a.affinityCount < b.affinityCount - } switch { case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: return true @@ -1639,9 +1538,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) } - if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && item.account.IsAffinityEnabled() { - _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, item.account.ID, ClientAffinityTTL) - } if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } @@ -1679,22 +1575,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // ============ Layer 1.3: 用户亲和预处理(pinned_users 自动注入) ============ - affinityFlow.preprocessPinnedUsers(accounts) - - // ============ Layer 1.4: 客户端亲和调度(优先于粘性会话) ============ - affinityHit := false - if affinityResult, hit, err := affinityFlow.trySelectAffinityAccount(); err != nil { - return nil, err - } else { - affinityHit = hit - if affinityResult != nil { - return affinityResult, nil - } - } - - // ============ Layer 1.5: 粘性会话(仅在无模型路由配置 且 亲和未命中时生效) ============ - if !affinityHit && len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { + // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============ + if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { accountID := stickyAccountID if accountID > 0 && !isExcluded(accountID) { account, ok := accountByID[accountID] @@ -1800,9 +1682,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { - if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && result.Account != nil && result.Account.IsAffinityEnabled() { - _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, result.Account.ID, ClientAffinityTTL) - } return result, nil } } else { @@ -1820,37 +1699,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // 批量获取亲和客户端数量(用于均衡分配新客户端) - s.populateAffinityCounts(ctx, available, derefGroupID(groupID)) - - // 分层过滤选择:优先级 → 亲和三区 → 负载率 → 亲和客户端数 → LRU + // 分层过滤选择:优先级 → 负载率 → LRU for len(available) > 0 { // 1. 取优先级最小的集合 candidates := filterByMinPriority(available) - // 2. 按亲和三区过滤:绿区优先 → 黄区降级 → 红区移除(在同优先级内) - candidates = classifyByAffinityZone(candidates) - if len(candidates) == 0 { - // 当前优先级组全部在红区,移除后回退到下一优先级组 - minPri := available[0].account.Priority - for _, a := range available[1:] { - if a.account.Priority < minPri { - minPri = a.account.Priority - } - } - newAvailable := make([]accountWithLoad, 0, len(available)) - for _, a := range available { - if a.account.Priority != minPri { - newAvailable = append(newAvailable, a) - } - } - available = newAvailable - continue - } - // 3. 取负载率最低的集合 + // 2. 取负载率最低的集合 candidates = filterByMinLoadRate(candidates) - // 3. 取亲和客户端数最少的集合 - candidates = filterByMinAffinityCount(candidates) - // 4. LRU 选择最久未用的账号 + // 3. LRU 选择最久未用的账号 selected := selectByLRU(candidates, preferOAuth) if selected == nil { break @@ -1865,10 +1720,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) } - // 更新亲和关系 - if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && selected.account.IsAffinityEnabled() { - _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, selected.account.ID, ClientAffinityTTL) - } return &AccountSelectionResult{ Account: selected.account, Acquired: true, @@ -2077,9 +1928,6 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr } func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { - if platform == PlatformSora { - return s.listSoraSchedulableAccounts(ctx, groupID) - } if s.schedulerSnapshot != nil { accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err == nil { @@ -2176,53 +2024,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i return accounts, useMixed, nil } -func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) { - const useMixed = false - - var accounts []Account - var err error - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) - } else if groupID != nil { - accounts, err = s.accountRepo.ListByGroup(ctx, *groupID) - } else { - accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) - } - if err != nil { - slog.Debug("account_scheduling_list_failed", - "group_id", derefGroupID(groupID), - "platform", PlatformSora, - "error", err) - return nil, useMixed, err - } - - filtered := make([]Account, 0, len(accounts)) - for _, acc := range accounts { - if acc.Platform != PlatformSora { - continue - } - if !s.isSoraAccountSchedulable(&acc) { - continue - } - filtered = append(filtered, acc) - } - slog.Debug("account_scheduling_list_sora", - "group_id", derefGroupID(groupID), - "platform", PlatformSora, - "raw_count", len(accounts), - "filtered_count", len(filtered)) - for _, acc := range filtered { - slog.Debug("account_scheduling_account_detail", - "account_id", acc.ID, - "name", acc.Name, - "platform", acc.Platform, - "type", acc.Type, - "status", acc.Status, - "tls_fingerprint", acc.IsTLSFingerprintEnabled()) - } - return filtered, useMixed, nil -} - // IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 // 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, // 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 @@ -2247,33 +2048,10 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform return account.Platform == platform } -func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool { - return s.soraUnschedulableReason(account) == "" -} - -func (s *GatewayService) soraUnschedulableReason(account *Account) string { - if account == nil { - return "account_nil" - } - if account.Status != StatusActive { - return fmt.Sprintf("status=%s", account.Status) - } - if !account.Schedulable { - return "schedulable=false" - } - if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { - return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339)) - } - return "" -} - func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool { if account == nil { return false } - if account.Platform == PlatformSora { - return s.isSoraAccountSchedulable(account) - } return account.IsSchedulable() } @@ -2281,12 +2059,6 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte if account == nil { return false } - if account.Platform == PlatformSora { - if !s.isSoraAccountSchedulable(account) { - return false - } - return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0 - } return account.IsSchedulableForModelWithContext(ctx, requestedModel) } @@ -2626,36 +2398,6 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in return s.accountRepo.GetByID(ctx, accountID) } -// populateAffinityCounts 批量获取账号的亲和客户端数量并填入 accountWithLoad 切片。 -// 仅当存在开启了客户端亲和的账号时才查询 Redis,否则跳过。 -func (s *GatewayService) populateAffinityCounts(ctx context.Context, accounts []accountWithLoad, groupID int64) { - if s.cache == nil || len(accounts) == 0 { - return - } - // 快速检查:是否有任何账号开启了亲和 - hasAffinity := false - for _, acc := range accounts { - if acc.account.IsAffinityEnabled() { - hasAffinity = true - break - } - } - if !hasAffinity { - return - } - accountIDs := make([]int64, len(accounts)) - for i, acc := range accounts { - accountIDs[i] = acc.account.ID - } - countMap, err := s.cache.GetAccountAffinityCountBatch(ctx, groupID, accountIDs, ClientAffinityTTL) - if err != nil { - return // 查询失败不影响调度,affinityCount 保持 0 - } - for i := range accounts { - accounts[i].affinityCount = countMap[accounts[i].account.ID] - } -} - // filterByMinPriority 过滤出优先级最小的账号集合 func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { if len(accounts) == 0 { @@ -2696,64 +2438,6 @@ func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad { return result } -// filterByMinAffinityCount 过滤出亲和客户端数最少的账号集合 -func filterByMinAffinityCount(accounts []accountWithLoad) []accountWithLoad { - if len(accounts) == 0 { - return accounts - } - minCount := accounts[0].affinityCount - for _, acc := range accounts[1:] { - if acc.affinityCount < minCount { - minCount = acc.affinityCount - } - } - result := make([]accountWithLoad, 0, len(accounts)) - for _, acc := range accounts { - if acc.affinityCount == minCount { - result = append(result, acc) - } - } - return result -} - -// classifyByAffinityZone 按亲和分区对候选账号进行分类。 -// 返回值:仅绿区账号(有绿区时),否则返回黄区账号。红区账号被移除。 -// 如果没有任何账号开启了亲和三区配置(即 affinity_base <= 0),则原样返回所有账号。 -func classifyByAffinityZone(accounts []accountWithLoad) []accountWithLoad { - if len(accounts) == 0 { - return accounts - } - // 快速检查:是否有任何账号配置了 affinity_base - hasZoneConfig := false - for _, acc := range accounts { - if acc.account.IsAffinityEnabled() && acc.account.GetAffinityBase() > 0 { - hasZoneConfig = true - break - } - } - if !hasZoneConfig { - return accounts - } - - greens := make([]accountWithLoad, 0, len(accounts)) - yellows := make([]accountWithLoad, 0, len(accounts)) - for _, acc := range accounts { - zone := acc.account.GetAffinityZone(acc.affinityCount) - switch zone { - case AffinityZoneGreen: - greens = append(greens, acc) - case AffinityZoneYellow: - yellows = append(yellows, acc) - case AffinityZoneRed: - // 红区:移除,不参与调度 - } - } - if len(greens) > 0 { - return greens - } - return yellows -} - // selectByLRU 从集合中选择最久未用的账号 // 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个 func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad { @@ -3514,9 +3198,6 @@ func (s *GatewayService) logDetailedSelectionFailure( stats.SampleMappingIDs, stats.SampleRateLimitIDs, ) - if platform == PlatformSora { - s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling) - } return stats } @@ -3574,9 +3255,6 @@ func (s *GatewayService) diagnoseSelectionFailure( } if !s.isAccountSchedulableForSelection(acc) { detail := "generic_unschedulable" - if acc.Platform == PlatformSora { - detail = s.soraUnschedulableReason(acc) - } return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} } if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { @@ -3601,57 +3279,7 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "eligible"} } -func (s *GatewayService) logSoraSelectionFailureDetails( - ctx context.Context, - groupID *int64, - sessionHash string, - requestedModel string, - accounts []Account, - excludedIDs map[int64]struct{}, - allowMixedScheduling bool, -) { - const maxLines = 30 - logged := 0 - - for i := range accounts { - if logged >= maxLines { - break - } - acc := &accounts[i] - diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling) - if diagnosis.Category == "eligible" { - continue - } - detail := diagnosis.Detail - if detail == "" { - detail = "-" - } - logger.LegacyPrintf( - "service.gateway", - "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s", - derefGroupID(groupID), - requestedModel, - shortSessionHash(sessionHash), - acc.ID, - acc.Platform, - diagnosis.Category, - detail, - ) - logged++ - } - if len(accounts) > maxLines { - logger.LegacyPrintf( - "service.gateway", - "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d", - derefGroupID(groupID), - requestedModel, - shortSessionHash(sessionHash), - len(accounts), - logged, - ) - } -} - +// GetAccessToken 获取账号凭证 func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { if acc == nil { return true @@ -3730,9 +3358,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo } return mapAntigravityModel(account, requestedModel) != "" } - if account.Platform == PlatformSora { - return s.isSoraModelSupportedByAccount(account, requestedModel) - } if account.IsBedrock() { _, ok := ResolveBedrockModelID(account, requestedModel) return ok @@ -3749,143 +3374,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo return account.IsModelSupported(requestedModel) } -func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool { - if account == nil { - return false - } - if strings.TrimSpace(requestedModel) == "" { - return true - } - - // 先走原始精确/通配符匹配。 - mapping := account.GetModelMapping() - if len(mapping) == 0 || account.IsModelSupported(requestedModel) { - return true - } - - aliases := buildSoraModelAliases(requestedModel) - if len(aliases) == 0 { - return false - } - - hasSoraSelector := false - for pattern := range mapping { - if !isSoraModelSelector(pattern) { - continue - } - hasSoraSelector = true - if matchPatternAnyAlias(pattern, aliases) { - return true - } - } - - // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*), - // 此时不应误拦截 Sora 模型请求。 - if !hasSoraSelector { - return true - } - - return false -} - -func matchPatternAnyAlias(pattern string, aliases []string) bool { - normalizedPattern := strings.ToLower(strings.TrimSpace(pattern)) - if normalizedPattern == "" { - return false - } - for _, alias := range aliases { - if matchWildcard(normalizedPattern, alias) { - return true - } - } - return false -} - -func isSoraModelSelector(pattern string) bool { - p := strings.ToLower(strings.TrimSpace(pattern)) - if p == "" { - return false - } - - switch { - case strings.HasPrefix(p, "sora"), - strings.HasPrefix(p, "gpt-image"), - strings.HasPrefix(p, "prompt-enhance"), - strings.HasPrefix(p, "sy_"): - return true - } - - return p == "video" || p == "image" -} - -func buildSoraModelAliases(requestedModel string) []string { - modelID := strings.ToLower(strings.TrimSpace(requestedModel)) - if modelID == "" { - return nil - } - - aliases := make([]string, 0, 8) - addAlias := func(value string) { - v := strings.ToLower(strings.TrimSpace(value)) - if v == "" { - return - } - for _, existing := range aliases { - if existing == v { - return - } - } - aliases = append(aliases, v) - } - - addAlias(modelID) - cfg, ok := GetSoraModelConfig(modelID) - if ok { - addAlias(cfg.Model) - switch cfg.Type { - case "video": - addAlias("video") - addAlias("sora") - addAlias(soraVideoFamilyAlias(modelID)) - case "image": - addAlias("image") - addAlias("gpt-image") - case "prompt_enhance": - addAlias("prompt-enhance") - } - return aliases - } - - switch { - case strings.HasPrefix(modelID, "sora"): - addAlias("video") - addAlias("sora") - addAlias(soraVideoFamilyAlias(modelID)) - case strings.HasPrefix(modelID, "gpt-image"): - addAlias("image") - addAlias("gpt-image") - case strings.HasPrefix(modelID, "prompt-enhance"): - addAlias("prompt-enhance") - default: - return nil - } - - return aliases -} - -func soraVideoFamilyAlias(modelID string) string { - switch { - case strings.HasPrefix(modelID, "sora2pro-hd"): - return "sora2pro-hd" - case strings.HasPrefix(modelID, "sora2pro"): - return "sora2pro" - case strings.HasPrefix(modelID, "sora2"): - return "sora2" - default: - return "" - } -} - // GetAccessToken 获取账号凭证 func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { @@ -4412,10 +3900,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return nil, err } - // 获取代理URL + // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() + if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { + proxyURL = account.Proxy.URL() + } } // 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析) @@ -4824,7 +4314,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 处理正常响应 - ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed) // 触发上游接受回调(提前释放串行锁,不等流完成) if parsed.OnUpstreamAccepted != nil { @@ -5891,6 +5380,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } targetURL = validatedURL + "/v1/messages?beta=true" } + } else if account.IsCustomBaseURLEnabled() { + customURL := account.GetCustomBaseURL() + if customURL == "" { + return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID) + } + validatedURL, err := s.validateUpstreamBaseURL(customURL) + if err != nil { + return nil, err + } + targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account) } clientHeaders := http.Header{} @@ -6006,6 +5505,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } + // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 + if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { + if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { + if parsed := ParseMetadataUserID(uid); parsed != nil { + setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) + } + } + } + // === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 === s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{ "url": req.URL.String(), @@ -7005,7 +6513,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http needModelReplace := originalModel != mappedModel clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage sawTerminalEvent := false - skipAccountTTLOverride := false pendingEventLines := make([]string, 0, 4) @@ -7067,25 +6574,17 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { eventChanged = reconcileCachedTokens(u) || eventChanged - claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID) - if claudeMaxOutcome.Simulated { - skipAccountTTLOverride = true - } } } } if eventType == "message_delta" { if u, ok := event["usage"].(map[string]any); ok { eventChanged = reconcileCachedTokens(u) || eventChanged - claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID) - if claudeMaxOutcome.Simulated { - skipAccountTTLOverride = true - } } } // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride { + if account.IsCacheTTLOverrideEnabled() { overrideTarget := account.GetCacheTTLOverrideTarget() if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { @@ -7517,13 +7016,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } - claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID) - if claudeMaxOutcome.Simulated { - body = rewriteClaudeUsageJSONBytes(body, response.Usage) - } - // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated { + if account.IsCacheTTLOverrideEnabled() { overrideTarget := account.GetCacheTTLOverrideTarget() if applyCacheTTLOverride(&response.Usage, overrideTarget) { // 同步更新 body JSON 中的嵌套 cache_creation 对象 @@ -7901,12 +7395,10 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage // recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。 type recordUsageOpts struct { - // Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入) + // ParsedRequest(可选,仅 Claude 路径传入) ParsedRequest *ParsedRequest // EnableClaudePath 启用 Claude 路径特有逻辑: - // - Claude Max 缓存计费策略 - // - Sora 媒体类型分支(image/video/prompt) // - MediaType 字段写入使用日志 EnableClaudePath bool @@ -7998,8 +7490,6 @@ type recordUsageCoreInput struct { // recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 // opts 中的字段控制两者之间的差异行为: -// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略 -// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt) // - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { result := input.Result @@ -8017,21 +7507,9 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage result.Usage.InputTokens = 0 } - // Claude Max cache billing policy(仅 Claude 路径启用) - cacheTTLOverridden := false - simulatedClaudeMax := false - if opts.EnableClaudePath { - var apiKeyGroup *Group - if apiKey != nil { - apiKeyGroup = apiKey.Group - } - claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, opts.ParsedRequest, apiKeyGroup, result.Model, account.ID) - simulatedClaudeMax = claudeMaxOutcome.Simulated || - (shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, opts.ParsedRequest) && hasCacheCreationTokens(result.Usage)) - } - // Cache TTL Override: 确保计费时 token 分类与账号设置一致 - if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax { + cacheTTLOverridden := false + if account.IsCacheTTLOverrideEnabled() { applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 } @@ -8113,16 +7591,6 @@ func (s *GatewayService) calculateRecordUsageCost( multiplier float64, opts *recordUsageOpts, ) *CostBreakdown { - // Sora 媒体类型分支(仅 Claude 路径启用) - if opts.EnableClaudePath { - if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo { - return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier) - } - if result.MediaType == MediaTypePrompt { - return &CostBreakdown{} - } - } - // 图片生成计费 if result.ImageCount > 0 { return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier) @@ -8132,28 +7600,6 @@ func (s *GatewayService) calculateRecordUsageCost( return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts) } -// calculateSoraMediaCost 计算 Sora 图片/视频的费用。 -func (s *GatewayService) calculateSoraMediaCost( - result *ForwardResult, - apiKey *APIKey, - billingModel string, - multiplier float64, -) *CostBreakdown { - var soraConfig *SoraPriceConfig - if apiKey.Group != nil { - soraConfig = &SoraPriceConfig{ - ImagePrice360: apiKey.Group.SoraImagePrice360, - ImagePrice540: apiKey.Group.SoraImagePrice540, - VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, - VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, - } - } - if result.MediaType == MediaTypeImage { - return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) - } - return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) -} - // resolveChannelPricing 检查指定模型是否存在渠道级别定价。 // 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。 func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { @@ -8176,7 +7622,7 @@ func (s *GatewayService) calculateImageCost( billingModel string, multiplier float64, ) *CostBreakdown { - if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil { + if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { tokens := UsageTokens{ InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -8191,6 +7637,7 @@ func (s *GatewayService) calculateImageCost( RequestCount: 1, RateMultiplier: multiplier, Resolver: s.resolver, + Resolved: resolved, }) if err != nil { logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err) @@ -8233,7 +7680,7 @@ func (s *GatewayService) calculateTokenCost( var err error // 优先尝试渠道定价 → CalculateCostUnified - if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil { + if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { gid := apiKey.Group.ID cost, err = s.billingService.CalculateCostUnified(CostInput{ Ctx: ctx, @@ -8243,6 +7690,7 @@ func (s *GatewayService) calculateTokenCost( RequestCount: 1, RateMultiplier: multiplier, Resolver: s.resolver, + Resolved: resolved, }) } else if opts.LongContextThreshold > 0 { // 长上下文双倍计费(如 Gemini 200K 阈值) @@ -8330,13 +7778,7 @@ func (s *GatewayService) buildRecordUsageLog( } // resolveBillingMode 根据计费结果和请求类型确定计费模式。 -// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。 func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string { - isSoraMedia := opts.EnableClaudePath && - (result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt) - if isSoraMedia { - return nil - } var mode string switch { case cost != nil && cost.BillingMode != "": @@ -8350,9 +7792,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost } func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string { - if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" { - return &result.MediaType - } return nil } @@ -8559,10 +7998,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return err } - // 获取代理URL + // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() + if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { + proxyURL = account.Proxy.URL() + } } // 发送请求 @@ -8841,6 +8282,16 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" } + } else if account.IsCustomBaseURLEnabled() { + customURL := account.GetCustomBaseURL() + if customURL == "" { + return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID) + } + validatedURL, err := s.validateUpstreamBaseURL(customURL) + if err != nil { + return nil, err + } + targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account) } clientHeaders := http.Header{} @@ -8946,6 +8397,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } + // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 + if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { + if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { + if parsed := ParseMetadataUserID(uid); parsed != nil { + setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) + } + } + } + if c != nil && tokenType == "oauth" { c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) } @@ -8967,6 +8427,19 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m }) } +// buildCustomRelayURL 构建自定义中继转发 URL +// 在 path 后附加 beta=true 和可选的 proxy 查询参数 +func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string { + u := strings.TrimRight(baseURL, "/") + path + "?beta=true" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL := account.Proxy.URL() + if proxyURL != "" { + u += "&proxy=" + url.QueryEscape(proxyURL) + } + } + return u +} + func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) From 794e81720834913c5fb3251a5c13fd88040c354a Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 5 Apr 2026 16:39:24 +0800 Subject: [PATCH 07/88] refactor: remove PaymentChannel, reuse upstream Channel with features field MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Delete payment_channels table and PaymentChannel Ent schema - Add `features` column to upstream channels table (migration 095) - Add Features field to Channel struct, input types, handler request/response - Payment user/admin handlers now use ChannelService directly - Remove Channel CRUD from PaymentConfigService and admin payment routes - Remove "渠道管理" tab from admin orders page (use /admin/channels) --- backend/ent/client.go | 36 +- backend/ent/intercept/intercept.go | 1 + backend/ent/predicate/predicate.go | 1 + .../internal/handler/admin/channel_handler.go | 6 + backend/internal/handler/payment_handler.go | 1 + backend/internal/repository/channel_repo.go | 20 +- backend/internal/server/routes/payment.go | 13 +- backend/internal/service/channel.go | 1 + backend/internal/service/channel_service.go | 6 + .../service/payment_config_service.go | 440 +++++++++++------- backend/migrations/095_channel_features.sql | 2 + 11 files changed, 314 insertions(+), 213 deletions(-) create mode 100644 backend/migrations/095_channel_features.sql diff --git a/backend/ent/client.go b/backend/ent/client.go index e52e015a..3da7acf8 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -333,10 +333,10 @@ func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, - c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) @@ -349,10 +349,10 @@ func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, - c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) @@ -4629,19 +4629,19 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription type ( hooks struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, - PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, - UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, - UserAttributeValue, UserSubscription []ent.Hook + ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, + PaymentOrder, PaymentProviderInstance, PromoCode, + PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, + TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, + UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook } inters struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, - PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, - UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, - UserAttributeValue, UserSubscription []ent.Interceptor + ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, + PaymentOrder, PaymentProviderInstance, PromoCode, + PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, + TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, + UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 8d8320bb..77d3e16e 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -336,6 +336,7 @@ func (f TraversePaymentAuditLog) Traverse(ctx context.Context, q ent.Query) erro return fmt.Errorf("unexpected query type %T. expect *ent.PaymentAuditLogQuery", q) } + // The PaymentOrderFunc type is an adapter to allow the use of ordinary function as a Querier. type PaymentOrderFunc func(context.Context, *ent.PaymentOrderQuery) (ent.Value, error) diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index ef551940..67f37c75 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -33,6 +33,7 @@ type IdempotencyRecord func(*sql.Selector) // PaymentAuditLog is the predicate function for paymentauditlog builders. type PaymentAuditLog func(*sql.Selector) + // PaymentOrder is the predicate function for paymentorder builders. type PaymentOrder func(*sql.Selector) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index c92b35bb..d6022283 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -33,6 +33,7 @@ type createChannelRequest struct { ModelMapping map[string]map[string]string `json:"model_mapping"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` RestrictModels bool `json:"restrict_models"` + Features string `json:"features"` } type updateChannelRequest struct { @@ -44,6 +45,7 @@ type updateChannelRequest struct { ModelMapping map[string]map[string]string `json:"model_mapping"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` RestrictModels *bool `json:"restrict_models"` + Features *string `json:"features"` } type channelModelPricingRequest struct { @@ -78,6 +80,7 @@ type channelResponse struct { Status string `json:"status"` BillingModelSource string `json:"billing_model_source"` RestrictModels bool `json:"restrict_models"` + Features string `json:"features"` GroupIDs []int64 `json:"group_ids"` ModelPricing []channelModelPricingResponse `json:"model_pricing"` ModelMapping map[string]map[string]string `json:"model_mapping"` @@ -122,6 +125,7 @@ func channelToResponse(ch *service.Channel) *channelResponse { Description: ch.Description, Status: ch.Status, RestrictModels: ch.RestrictModels, + Features: ch.Features, GroupIDs: ch.GroupIDs, ModelMapping: ch.ModelMapping, CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), @@ -300,6 +304,7 @@ func (h *ChannelHandler) Create(c *gin.Context) { ModelMapping: req.ModelMapping, BillingModelSource: req.BillingModelSource, RestrictModels: req.RestrictModels, + Features: req.Features, }) if err != nil { response.ErrorFrom(c, err) @@ -332,6 +337,7 @@ func (h *ChannelHandler) Update(c *gin.Context) { ModelMapping: req.ModelMapping, BillingModelSource: req.BillingModelSource, RestrictModels: req.RestrictModels, + Features: req.Features, } if req.ModelPricing != nil { pricing := pricingRequestToService(*req.ModelPricing) diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index 0425fc49..e01a2af1 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -7,6 +7,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go index 49c2d8d9..710322fb 100644 --- a/backend/internal/repository/channel_repo.go +++ b/backend/internal/repository/channel_repo.go @@ -42,9 +42,9 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel return err } err = tx.QueryRowContext(ctx, - `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6) + `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id, created_at, updated_at`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt) if err != nil { if isUniqueViolation(err) { @@ -75,9 +75,9 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha ch := &service.Channel{} var modelMappingJSON []byte err := r.db.QueryRowContext(ctx, - `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at FROM channels WHERE id = $1`, id, - ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt) + ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt) if err == sql.ErrNoRows { return nil, service.ErrChannelNotFound } @@ -108,9 +108,9 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel return err } result, err := tx.ExecContext(ctx, - `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW() - WHERE id = $7`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID, + `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, updated_at = NOW() + WHERE id = $8`, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ID, ) if err != nil { if isUniqueViolation(err) { @@ -204,7 +204,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati for rows.Next() { var ch service.Channel var modelMappingJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) @@ -273,7 +273,7 @@ func channelListOrderBy(params pagination.PaginationParams) string { func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { rows, err := r.db.QueryContext(ctx, - `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`, + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at FROM channels ORDER BY id`, ) if err != nil { return nil, fmt.Errorf("query all channels: %w", err) @@ -285,7 +285,7 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err for rows.Next() { var ch service.Channel var modelMappingJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go index 6bf04679..828b68f3 100644 --- a/backend/internal/server/routes/payment.go +++ b/backend/internal/server/routes/payment.go @@ -26,7 +26,6 @@ func RegisterPaymentRoutes( authenticated.Use(middleware.BackendModeUserGuard(settingService)) { authenticated.GET("/config", paymentHandler.GetPaymentConfig) - authenticated.GET("/checkout-info", paymentHandler.GetCheckoutInfo) authenticated.GET("/plans", paymentHandler.GetPlans) authenticated.GET("/channels", paymentHandler.GetChannels) authenticated.GET("/limits", paymentHandler.GetLimits) @@ -34,7 +33,6 @@ func RegisterPaymentRoutes( orders := authenticated.Group("/orders") { orders.POST("", paymentHandler.CreateOrder) - orders.POST("/verify", paymentHandler.VerifyOrder) orders.GET("/my", paymentHandler.GetMyOrders) orders.GET("/:id", paymentHandler.GetOrder) orders.POST("/:id/cancel", paymentHandler.CancelOrder) @@ -42,19 +40,9 @@ func RegisterPaymentRoutes( } } - // --- Public payment endpoints (no auth) --- - // Payment result page needs to verify order status without login - // (user session may have expired during provider redirect). - public := v1.Group("/payment/public") - { - public.POST("/orders/verify", paymentHandler.VerifyOrderPublic) - } - // --- Webhook endpoints (no auth) --- webhook := v1.Group("/payment/webhook") { - // EasyPay sends GET callbacks with query params - webhook.GET("/easypay", webhookHandler.EasyPayNotify) webhook.POST("/easypay", webhookHandler.EasyPayNotify) webhook.POST("/alipay", webhookHandler.AlipayNotify) webhook.POST("/wxpay", webhookHandler.WxpayNotify) @@ -82,6 +70,7 @@ func RegisterPaymentRoutes( adminOrders.POST("/:id/refund", adminPaymentHandler.ProcessRefund) } + // Subscription Plans plans := adminGroup.Group("/plans") { diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index 1697ed6f..eac81444 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -39,6 +39,7 @@ type Channel struct { Status string BillingModelSource string // "requested", "upstream", or "channel_mapped" RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) + Features string // 渠道特性描述(JSON 数组),用于支付页面展示 CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index ec8310f6..cdf94a4c 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -584,6 +584,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) GroupIDs: input.GroupIDs, ModelPricing: input.ModelPricing, ModelMapping: input.ModelMapping, + Features: input.Features, } if channel.BillingModelSource == "" { channel.BillingModelSource = BillingModelSourceChannelMapped @@ -641,6 +642,9 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan if input.RestrictModels != nil { channel.RestrictModels = *input.RestrictModels } + if input.Features != nil { + channel.Features = *input.Features + } // 检查分组冲突 if input.GroupIDs != nil { @@ -842,6 +846,7 @@ type CreateChannelInput struct { ModelMapping map[string]map[string]string // platform → {src→dst} BillingModelSource string RestrictModels bool + Features string } // UpdateChannelInput 更新渠道输入 @@ -854,4 +859,5 @@ type UpdateChannelInput struct { ModelMapping map[string]map[string]string // platform → {src→dst} BillingModelSource string RestrictModels *bool + Features *string } diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 9042c3ab..dafe9afd 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -2,13 +2,16 @@ package service import ( "context" + "encoding/json" "fmt" "strconv" "strings" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/subscriptionplan" "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) const ( @@ -23,8 +26,6 @@ const ( SettingBalancePayDisabled = "BALANCE_PAYMENT_DISABLED" SettingProductNamePrefix = "PRODUCT_NAME_PREFIX" SettingProductNameSuffix = "PRODUCT_NAME_SUFFIX" - SettingHelpImageURL = "PAYMENT_HELP_IMAGE_URL" - SettingHelpText = "PAYMENT_HELP_TEXT" SettingCancelRateLimitOn = "CANCEL_RATE_LIMIT_ENABLED" SettingCancelRateLimitMax = "CANCEL_RATE_LIMIT_MAX" SettingCancelWindowSize = "CANCEL_RATE_LIMIT_WINDOW" @@ -32,126 +33,91 @@ const ( SettingCancelWindowMode = "CANCEL_RATE_LIMIT_WINDOW_MODE" ) -// Default values for payment configuration settings. -const ( - defaultOrderTimeoutMin = 30 - defaultMaxPendingOrders = 3 -) - // PaymentConfig holds the payment system configuration. type PaymentConfig struct { - Enabled bool `json:"enabled"` - MinAmount float64 `json:"min_amount"` - MaxAmount float64 `json:"max_amount"` - DailyLimit float64 `json:"daily_limit"` - OrderTimeoutMin int `json:"order_timeout_minutes"` - MaxPendingOrders int `json:"max_pending_orders"` - EnabledTypes []string `json:"enabled_payment_types"` - BalanceDisabled bool `json:"balance_disabled"` - LoadBalanceStrategy string `json:"load_balance_strategy"` - ProductNamePrefix string `json:"product_name_prefix"` - ProductNameSuffix string `json:"product_name_suffix"` - HelpImageURL string `json:"help_image_url"` - HelpText string `json:"help_text"` - StripePublishableKey string `json:"stripe_publishable_key,omitempty"` - - // Cancel rate limit settings - CancelRateLimitEnabled bool `json:"cancel_rate_limit_enabled"` - CancelRateLimitMax int `json:"cancel_rate_limit_max"` - CancelRateLimitWindow int `json:"cancel_rate_limit_window"` - CancelRateLimitUnit string `json:"cancel_rate_limit_unit"` - CancelRateLimitMode string `json:"cancel_rate_limit_window_mode"` + Enabled bool `json:"enabled"` + MinAmount float64 `json:"minAmount"` + MaxAmount float64 `json:"maxAmount"` + DailyLimit float64 `json:"dailyLimit"` + OrderTimeoutMin int `json:"orderTimeoutMinutes"` + MaxPendingOrders int `json:"maxPendingOrders"` + EnabledTypes []string `json:"enabledTypes"` + BalanceDisabled bool `json:"balanceDisabled"` + LoadBalanceStrategy string `json:"loadBalanceStrategy"` + ProductNamePrefix string `json:"productNamePrefix"` + ProductNameSuffix string `json:"productNameSuffix"` } // UpdatePaymentConfigRequest contains fields to update payment configuration. type UpdatePaymentConfigRequest struct { Enabled *bool `json:"enabled"` - MinAmount *float64 `json:"min_amount"` - MaxAmount *float64 `json:"max_amount"` - DailyLimit *float64 `json:"daily_limit"` - OrderTimeoutMin *int `json:"order_timeout_minutes"` - MaxPendingOrders *int `json:"max_pending_orders"` - EnabledTypes []string `json:"enabled_payment_types"` - BalanceDisabled *bool `json:"balance_disabled"` - LoadBalanceStrategy *string `json:"load_balance_strategy"` - ProductNamePrefix *string `json:"product_name_prefix"` - ProductNameSuffix *string `json:"product_name_suffix"` - HelpImageURL *string `json:"help_image_url"` - HelpText *string `json:"help_text"` - - // Cancel rate limit settings - CancelRateLimitEnabled *bool `json:"cancel_rate_limit_enabled"` - CancelRateLimitMax *int `json:"cancel_rate_limit_max"` - CancelRateLimitWindow *int `json:"cancel_rate_limit_window"` - CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"` - CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"` + MinAmount *float64 `json:"minAmount"` + MaxAmount *float64 `json:"maxAmount"` + DailyLimit *float64 `json:"dailyLimit"` + OrderTimeoutMin *int `json:"orderTimeoutMinutes"` + MaxPendingOrders *int `json:"maxPendingOrders"` + EnabledTypes []string `json:"enabledTypes"` + BalanceDisabled *bool `json:"balanceDisabled"` + LoadBalanceStrategy *string `json:"loadBalanceStrategy"` + ProductNamePrefix *string `json:"productNamePrefix"` + ProductNameSuffix *string `json:"productNameSuffix"` } // MethodLimits holds per-payment-type limits. type MethodLimits struct { - PaymentType string `json:"payment_type"` - FeeRate float64 `json:"fee_rate"` - DailyLimit float64 `json:"daily_limit"` - SingleMin float64 `json:"single_min"` - SingleMax float64 `json:"single_max"` -} - -// MethodLimitsResponse is the full response for the user-facing /limits API. -// It includes per-method limits and the global widest range (union of all methods). -type MethodLimitsResponse struct { - Methods map[string]MethodLimits `json:"methods"` - GlobalMin float64 `json:"global_min"` // 0 = no minimum - GlobalMax float64 `json:"global_max"` // 0 = no maximum + PaymentType string `json:"paymentType"` + FeeRate float64 `json:"feeRate"` + DailyLimit float64 `json:"dailyLimit"` + SingleMin float64 `json:"singleMin"` + SingleMax float64 `json:"singleMax"` } type CreateProviderInstanceRequest struct { - ProviderKey string `json:"provider_key"` + ProviderKey string `json:"providerKey"` Name string `json:"name"` Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` + SupportedTypes string `json:"supportedTypes"` Enabled bool `json:"enabled"` - PaymentMode string `json:"payment_mode"` - SortOrder int `json:"sort_order"` + SortOrder int `json:"sortOrder"` Limits string `json:"limits"` - RefundEnabled bool `json:"refund_enabled"` + RefundEnabled bool `json:"refundEnabled"` } type UpdateProviderInstanceRequest struct { Name *string `json:"name"` Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` + SupportedTypes *string `json:"supportedTypes"` Enabled *bool `json:"enabled"` - PaymentMode *string `json:"payment_mode"` - SortOrder *int `json:"sort_order"` + SortOrder *int `json:"sortOrder"` Limits *string `json:"limits"` - RefundEnabled *bool `json:"refund_enabled"` + RefundEnabled *bool `json:"refundEnabled"` } type CreatePlanRequest struct { - GroupID int64 `json:"group_id"` + GroupID int64 `json:"groupId"` Name string `json:"name"` Description string `json:"description"` Price float64 `json:"price"` - OriginalPrice *float64 `json:"original_price"` - ValidityDays int `json:"validity_days"` - ValidityUnit string `json:"validity_unit"` + OriginalPrice *float64 `json:"originalPrice"` + ValidityDays int `json:"validityDays"` + ValidityUnit string `json:"validityUnit"` Features string `json:"features"` - ProductName string `json:"product_name"` - ForSale bool `json:"for_sale"` - SortOrder int `json:"sort_order"` + ProductName string `json:"productName"` + ForSale bool `json:"forSale"` + SortOrder int `json:"sortOrder"` } type UpdatePlanRequest struct { - GroupID *int64 `json:"group_id"` + GroupID *int64 `json:"groupId"` Name *string `json:"name"` Description *string `json:"description"` Price *float64 `json:"price"` - OriginalPrice *float64 `json:"original_price"` - ValidityDays *int `json:"validity_days"` - ValidityUnit *string `json:"validity_unit"` + OriginalPrice *float64 `json:"originalPrice"` + ValidityDays *int `json:"validityDays"` + ValidityUnit *string `json:"validityUnit"` Features *string `json:"features"` - ProductName *string `json:"product_name"` - ForSale *bool `json:"for_sale"` - SortOrder *int `json:"sort_order"` + ProductName *string `json:"productName"` + ForSale *bool `json:"forSale"` + SortOrder *int `json:"sortOrder"` } // PaymentConfigService manages payment configuration and CRUD for @@ -183,43 +149,29 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo SettingDailyRechargeLimit, SettingOrderTimeoutMinutes, SettingMaxPendingOrders, SettingEnabledPaymentTypes, SettingBalancePayDisabled, SettingLoadBalanceStrategy, SettingProductNamePrefix, SettingProductNameSuffix, - SettingHelpImageURL, SettingHelpText, - SettingCancelRateLimitOn, SettingCancelRateLimitMax, - SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode, } vals, err := s.settingRepo.GetMultiple(ctx, keys) if err != nil { return nil, fmt.Errorf("get payment config settings: %w", err) } - cfg := s.parsePaymentConfig(vals) - // Load Stripe publishable key from the first enabled Stripe provider instance - cfg.StripePublishableKey = s.getStripePublishableKey(ctx) - return cfg, nil + return s.parsePaymentConfig(vals), nil } func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *PaymentConfig { cfg := &PaymentConfig{ Enabled: vals[SettingPaymentEnabled] == "true", MinAmount: pcParseFloat(vals[SettingMinRechargeAmount], 1), - MaxAmount: pcParseFloat(vals[SettingMaxRechargeAmount], 0), + MaxAmount: pcParseFloat(vals[SettingMaxRechargeAmount], 99999999.99), DailyLimit: pcParseFloat(vals[SettingDailyRechargeLimit], 0), - OrderTimeoutMin: pcParseInt(vals[SettingOrderTimeoutMinutes], defaultOrderTimeoutMin), - MaxPendingOrders: pcParseInt(vals[SettingMaxPendingOrders], defaultMaxPendingOrders), + OrderTimeoutMin: pcParseInt(vals[SettingOrderTimeoutMinutes], 30), + MaxPendingOrders: pcParseInt(vals[SettingMaxPendingOrders], 3), BalanceDisabled: vals[SettingBalancePayDisabled] == "true", LoadBalanceStrategy: vals[SettingLoadBalanceStrategy], ProductNamePrefix: vals[SettingProductNamePrefix], ProductNameSuffix: vals[SettingProductNameSuffix], - HelpImageURL: vals[SettingHelpImageURL], - HelpText: vals[SettingHelpText], - - CancelRateLimitEnabled: vals[SettingCancelRateLimitOn] == "true", - CancelRateLimitMax: pcParseInt(vals[SettingCancelRateLimitMax], 10), - CancelRateLimitWindow: pcParseInt(vals[SettingCancelWindowSize], 1), - CancelRateLimitUnit: vals[SettingCancelWindowUnit], - CancelRateLimitMode: vals[SettingCancelWindowMode], } if cfg.LoadBalanceStrategy == "" { - cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy + cfg.LoadBalanceStrategy = "round-robin" } if raw := vals[SettingEnabledPaymentTypes]; raw != "" { for _, t := range strings.Split(raw, ",") { @@ -232,100 +184,242 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme return cfg } -// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance. -func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string { - instances, err := s.entClient.PaymentProviderInstance.Query(). - Where( - paymentproviderinstance.EnabledEQ(true), - paymentproviderinstance.ProviderKeyEQ(payment.TypeStripe), - ).Limit(1).All(ctx) - if err != nil || len(instances) == 0 { - return "" - } - cfg, err := s.decryptConfig(instances[0].Config) - if err != nil || cfg == nil { - return "" - } - return cfg[payment.ConfigKeyPublishableKey] -} - // UpdatePaymentConfig updates the payment configuration settings. -// NOTE: This function exceeds 30 lines because each field requires an independent -// nil-check before serialisation — this is inherent to patch-style update patterns -// and cannot be meaningfully decomposed without introducing unnecessary abstraction. func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req UpdatePaymentConfigRequest) error { - m := map[string]string{ - SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled), - SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount), - SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount), - SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit), - SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin), - SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders), - SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled), - SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy), - SettingProductNamePrefix: derefStr(req.ProductNamePrefix), - SettingProductNameSuffix: derefStr(req.ProductNameSuffix), - SettingHelpImageURL: derefStr(req.HelpImageURL), - SettingHelpText: derefStr(req.HelpText), - SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled), - SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax), - SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow), - SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit), - SettingCancelWindowMode: derefStr(req.CancelRateLimitMode), + m := make(map[string]string) + if req.Enabled != nil { + m[SettingPaymentEnabled] = strconv.FormatBool(*req.Enabled) + } + if req.MinAmount != nil { + m[SettingMinRechargeAmount] = strconv.FormatFloat(*req.MinAmount, 'f', 2, 64) + } + if req.MaxAmount != nil { + m[SettingMaxRechargeAmount] = strconv.FormatFloat(*req.MaxAmount, 'f', 2, 64) + } + if req.DailyLimit != nil { + m[SettingDailyRechargeLimit] = strconv.FormatFloat(*req.DailyLimit, 'f', 2, 64) + } + if req.OrderTimeoutMin != nil { + m[SettingOrderTimeoutMinutes] = strconv.Itoa(*req.OrderTimeoutMin) + } + if req.MaxPendingOrders != nil { + m[SettingMaxPendingOrders] = strconv.Itoa(*req.MaxPendingOrders) } if req.EnabledTypes != nil { m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",") - } else { - m[SettingEnabledPaymentTypes] = "" + } + if req.BalanceDisabled != nil { + m[SettingBalancePayDisabled] = strconv.FormatBool(*req.BalanceDisabled) + } + if req.LoadBalanceStrategy != nil { + m[SettingLoadBalanceStrategy] = *req.LoadBalanceStrategy + } + if req.ProductNamePrefix != nil { + m[SettingProductNamePrefix] = *req.ProductNamePrefix + } + if req.ProductNameSuffix != nil { + m[SettingProductNameSuffix] = *req.ProductNameSuffix + } + if len(m) == 0 { + return nil } return s.settingRepo.SetMultiple(ctx, m) } -func formatBoolOrEmpty(v *bool) string { - if v == nil { - return "" - } - return strconv.FormatBool(*v) +// --- Provider Instance CRUD --- + +func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*dbent.PaymentProviderInstance, error) { + return s.entClient.PaymentProviderInstance.Query().Order(paymentproviderinstance.BySortOrder()).All(ctx) } -func formatPositiveFloat(v *float64) string { - if v == nil || *v <= 0 { - return "" // empty → parsePaymentConfig uses default +func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req CreateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) { + enc, err := s.encryptConfig(req.Config) + if err != nil { + return nil, err } - return strconv.FormatFloat(*v, 'f', 2, 64) + return s.entClient.PaymentProviderInstance.Create(). + SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc). + SetSupportedTypes(req.SupportedTypes).SetEnabled(req.Enabled). + SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled). + Save(ctx) } -func formatPositiveInt(v *int) string { - if v == nil || *v <= 0 { - return "" +func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) { + u := s.entClient.PaymentProviderInstance.UpdateOneID(id) + if req.Name != nil { + u.SetName(*req.Name) } - return strconv.Itoa(*v) + if req.Config != nil { + enc, err := s.encryptConfig(req.Config) + if err != nil { + return nil, err + } + u.SetConfig(enc) + } + if req.SupportedTypes != nil { + u.SetSupportedTypes(*req.SupportedTypes) + } + if req.Enabled != nil { + u.SetEnabled(*req.Enabled) + } + if req.SortOrder != nil { + u.SetSortOrder(*req.SortOrder) + } + if req.Limits != nil { + u.SetLimits(*req.Limits) + } + if req.RefundEnabled != nil { + u.SetRefundEnabled(*req.RefundEnabled) + } + return u.Save(ctx) } -func derefStr(v *string) string { - if v == nil { - return "" - } - return *v +func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error { + return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx) } -func splitTypes(s string) []string { - if s == "" { - return nil +func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) { + data, err := json.Marshal(cfg) + if err != nil { + return "", fmt.Errorf("marshal config: %w", err) } - parts := strings.Split(s, ",") - result := make([]string, 0, len(parts)) - for _, p := range parts { - p = strings.TrimSpace(p) - if p != "" { - result = append(result, p) + enc, err := payment.Encrypt(string(data), s.encryptionKey) + if err != nil { + return "", fmt.Errorf("encrypt config: %w", err) + } + return enc, nil +} + +// --- Channel CRUD --- + + +// --- Plan CRUD --- + +func (s *PaymentConfigService) ListPlans(ctx context.Context) ([]*dbent.SubscriptionPlan, error) { + return s.entClient.SubscriptionPlan.Query().Order(subscriptionplan.BySortOrder()).All(ctx) +} + +func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.SubscriptionPlan, error) { + return s.entClient.SubscriptionPlan.Query().Where(subscriptionplan.ForSaleEQ(true)).Order(subscriptionplan.BySortOrder()).All(ctx) +} + +func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) { + b := s.entClient.SubscriptionPlan.Create(). + SetGroupID(req.GroupID).SetName(req.Name).SetDescription(req.Description). + SetPrice(req.Price).SetValidityDays(req.ValidityDays).SetValidityUnit(req.ValidityUnit). + SetFeatures(req.Features).SetProductName(req.ProductName). + SetForSale(req.ForSale).SetSortOrder(req.SortOrder) + if req.OriginalPrice != nil { + b.SetOriginalPrice(*req.OriginalPrice) + } + return b.Save(ctx) +} + +func (s *PaymentConfigService) UpdatePlan(ctx context.Context, id int64, req UpdatePlanRequest) (*dbent.SubscriptionPlan, error) { + u := s.entClient.SubscriptionPlan.UpdateOneID(id) + if req.GroupID != nil { + u.SetGroupID(*req.GroupID) + } + if req.Name != nil { + u.SetName(*req.Name) + } + if req.Description != nil { + u.SetDescription(*req.Description) + } + if req.Price != nil { + u.SetPrice(*req.Price) + } + if req.OriginalPrice != nil { + u.SetOriginalPrice(*req.OriginalPrice) + } + if req.ValidityDays != nil { + u.SetValidityDays(*req.ValidityDays) + } + if req.ValidityUnit != nil { + u.SetValidityUnit(*req.ValidityUnit) + } + if req.Features != nil { + u.SetFeatures(*req.Features) + } + if req.ProductName != nil { + u.SetProductName(*req.ProductName) + } + if req.ForSale != nil { + u.SetForSale(*req.ForSale) + } + if req.SortOrder != nil { + u.SetSortOrder(*req.SortOrder) + } + return u.Save(ctx) +} + +func (s *PaymentConfigService) DeletePlan(ctx context.Context, id int64) error { + return s.entClient.SubscriptionPlan.DeleteOneID(id).Exec(ctx) +} + +// GetPlan returns a subscription plan by ID. +func (s *PaymentConfigService) GetPlan(ctx context.Context, id int64) (*dbent.SubscriptionPlan, error) { + plan, err := s.entClient.SubscriptionPlan.Get(ctx, id) + if err != nil { + return nil, infraerrors.NotFound("PLAN_NOT_FOUND", "subscription plan not found") + } + return plan, nil +} + +// GetMethodLimits returns per-payment-type limits from enabled provider instances. +func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) { + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where(paymentproviderinstance.EnabledEQ(true)).All(ctx) + if err != nil { + return nil, fmt.Errorf("query provider instances: %w", err) + } + result := make([]MethodLimits, 0, len(types)) + for _, pt := range types { + ml := MethodLimits{PaymentType: pt} + for _, inst := range instances { + if !pcInstanceSupportsType(inst, pt) { + continue + } + pcApplyInstanceLimits(inst, pt, &ml) + } + result = append(result, ml) + } + return result, nil +} + +func pcInstanceSupportsType(inst *dbent.PaymentProviderInstance, pt string) bool { + if inst.SupportedTypes == "" { + return true + } + for _, t := range strings.Split(inst.SupportedTypes, ",") { + if strings.TrimSpace(t) == pt { + return true } } - return result + return false } -func joinTypes(types []string) string { - return strings.Join(types, ",") +func pcApplyInstanceLimits(inst *dbent.PaymentProviderInstance, pt string, ml *MethodLimits) { + if inst.Limits == "" { + return + } + var limits payment.InstanceLimits + if err := json.Unmarshal([]byte(inst.Limits), &limits); err != nil { + return + } + cl, ok := limits[pt] + if !ok { + return + } + if cl.DailyLimit > 0 && (ml.DailyLimit == 0 || cl.DailyLimit < ml.DailyLimit) { + ml.DailyLimit = cl.DailyLimit + } + if cl.SingleMin > 0 && (ml.SingleMin == 0 || cl.SingleMin > ml.SingleMin) { + ml.SingleMin = cl.SingleMin + } + if cl.SingleMax > 0 && (ml.SingleMax == 0 || cl.SingleMax < ml.SingleMax) { + ml.SingleMax = cl.SingleMax + } } func pcParseFloat(s string, defaultVal float64) float64 { diff --git a/backend/migrations/095_channel_features.sql b/backend/migrations/095_channel_features.sql new file mode 100644 index 00000000..5f142002 --- /dev/null +++ b/backend/migrations/095_channel_features.sql @@ -0,0 +1,2 @@ +ALTER TABLE channels ADD COLUMN IF NOT EXISTS features TEXT NOT NULL DEFAULT ''; +COMMENT ON COLUMN channels.features IS '渠道特性描述,JSON 数组格式,用于支付页面展示'; From 3d4d960d60dde5ffaa8213ef0a301a139d471431 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 5 Apr 2026 23:14:23 +0800 Subject: [PATCH 08/88] fix: gofmt formatting after merge --- .../service/admin_service_clear_error_test.go | 24 +++++++++---------- .../internal/service/billing_service_test.go | 1 - 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/backend/internal/service/admin_service_clear_error_test.go b/backend/internal/service/admin_service_clear_error_test.go index f039612c..141466dc 100644 --- a/backend/internal/service/admin_service_clear_error_test.go +++ b/backend/internal/service/admin_service_clear_error_test.go @@ -12,12 +12,12 @@ import ( type accountRepoStubForClearAccountError struct { mockAccountRepoForGemini - account *Account - clearErrorCalls int - clearRateLimitCalls int - clearAntigravityCalls int + account *Account + clearErrorCalls int + clearRateLimitCalls int + clearAntigravityCalls int clearModelRateLimitCalls int - clearTempUnschedCalls int + clearTempUnschedCalls int } func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) { @@ -60,13 +60,13 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes resetAt := time.Now().Add(5 * time.Minute) repo := &accountRepoStubForClearAccountError{ account: &Account{ - ID: 31, - Platform: PlatformOpenAI, - Type: AccountTypeOAuth, - Status: StatusError, - ErrorMessage: "refresh failed", - RateLimitResetAt: &resetAt, - TempUnschedulableUntil: &until, + ID: 31, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusError, + ErrorMessage: "refresh failed", + RateLimitResetAt: &resetAt, + TempUnschedulableUntil: &until, TempUnschedulableReason: "missing refresh token", }, } diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index dd58502c..6f6c41ce 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -363,7 +363,6 @@ func TestCalculateImageCost(t *testing.T) { require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10) } - func TestIsModelSupported(t *testing.T) { svc := newTestBillingService() From 1c63ea1448b6fac211751af8562aef69f2a4ae64 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 7 Apr 2026 13:47:12 +0800 Subject: [PATCH 09/88] fix(channel): add missing features column to List query The paginated List query was selecting 9 columns but scanning 10 fields, missing c.features. GetByID and ListAll already included it correctly. --- backend/cmd/server/VERSION | 2 +- backend/internal/repository/channel_repo.go | 31 ++------------------- 2 files changed, 4 insertions(+), 29 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 1ebb081e..8b4a4750 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.105.13 +0.1.108.73 diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go index 710322fb..baad31f7 100644 --- a/backend/internal/repository/channel_repo.go +++ b/backend/internal/repository/channel_repo.go @@ -187,9 +187,9 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati // 查询 channel 列表 dataQuery := fmt.Sprintf( - `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at - FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`, - whereClause, channelListOrderBy(params), argIdx, argIdx+1, + `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.created_at, c.updated_at + FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`, + whereClause, argIdx, argIdx+1, ) args = append(args, pageSize, offset) @@ -246,31 +246,6 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati return channels, paginationResult, nil } -func channelListOrderBy(params pagination.PaginationParams) string { - sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) - sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderAsc)) - - var column string - switch sortBy { - case "": - column = "c.id" - sortOrder = "ASC" - case "id": - column = "c.id" - case "name": - column = "c.name" - case "status": - column = "c.status" - case "created_at": - column = "c.created_at" - default: - column = "c.id" - sortOrder = "ASC" - } - - return fmt.Sprintf("%s %s, c.id %s", column, sortOrder, sortOrder) -} - func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { rows, err := r.db.QueryContext(ctx, `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at FROM channels ORDER BY id`, From 5bae3b05773f23722bae7600e9f14a1c192b64d1 Mon Sep 17 00:00:00 2001 From: erio Date: Wed, 8 Apr 2026 17:11:32 +0800 Subject: [PATCH 10/88] fix(payment): audit fixes for alipay/wxpay/stripe payment providers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Backend: - Extract YuanToFen/FenToYuan to payment/amount.go using shopspring/decimal - Require alipay publicKey in config validation - Fix wxpay webhook response to return JSON per V3 spec - Remove wxpay certSerial fallback to publicKeyId - Define magic strings as named constants in wxpay/alipay providers - Add slog warning for wxpay H5→Native payment downgrade - Make EncryptionKey validation return error on invalid (non-empty) key - Make decryptConfig propagate errors instead of returning nil - Add idempotency check in doBalance to prevent stuck FAILED retries Frontend: - Fix dashboard currency symbol from $ to ¥ - Fix AdminPaymentPlansView any type to proper SubscriptionPlan type - Make quick amount buttons follow selected payment method limits - Center help image with larger height and text below --- backend/cmd/server/wire_gen.go | 62 +- .../handler/payment_webhook_handler.go | 32 +- .../service/payment_config_service.go | 440 ++++++-------- .../internal/service/payment_fulfillment.go | 112 +--- .../orders/AdminPaymentDashboardView.vue | 4 +- frontend/src/views/user/PaymentView.vue | 557 +++++------------- 6 files changed, 388 insertions(+), 819 deletions(-) diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index c288a289..0a0cc84b 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -50,7 +50,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { refreshTokenCache := repository.NewRefreshTokenCache(redisClient) settingRepository := repository.NewSettingRepository(client) groupRepository := repository.NewGroupRepository(client, db) - channelRepository := repository.NewChannelRepository(db) settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) @@ -65,7 +64,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userGroupRateRepository := repository.NewUserGroupRateRepository(db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) - apiKeyService.SetRateLimitCacheInvalidator(billingCache) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) @@ -73,15 +71,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) - registry := payment.ProvideRegistry() - encryptionKey, err := payment.ProvideEncryptionKey(configConfig) - if err != nil { - return nil, err - } - defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) - paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) - paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) - paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) secretEncryptor, err := repository.NewAESEncryptor(configConfig) if err != nil { return nil, err @@ -92,7 +81,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) - usageBillingRepository := repository.NewUsageBillingRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemHandler := handler.NewRedeemHandler(redeemService) @@ -110,7 +98,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { } dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig) dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) - schedulerCache := repository.ProvideSchedulerCache(redisClient, configConfig) + schedulerCache := repository.NewSchedulerCache(redisClient) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) @@ -120,11 +108,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) + sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) + rpmCache := repository.NewRPMCache(redisClient) + groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache) + groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService) claudeOAuthClient := repository.NewClaudeOAuthClient() oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) openAIOAuthClient := repository.NewOpenAIOAuthClient() openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) - openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory) geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() driveClient := repository.NewGeminiDriveClient() @@ -134,7 +125,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tempUnschedCache := repository.NewTempUnschedCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) geminiTokenCache := repository.NewGeminiTokenCache(redisClient) - oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache) compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache) rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator) httpUpstream := repository.NewHTTPUpstream(configConfig) @@ -142,23 +132,20 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) usageCache := service.NewUsageCache() identityCache := repository.NewIdentityCache(redisClient) - geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI) - gatewayCache := repository.NewGatewayCache(redisClient) - schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) - schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) - antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache) - internal500CounterCache := repository.NewInternal500CounterCache(redisClient) tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client) tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) - antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) + oAuthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache) + geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI) + gatewayCache := repository.NewGatewayCache(redisClient) + schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) + schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) + antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache) + internal500CounterCache := repository.NewInternal500CounterCache(redisClient) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache, accountUsageService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) - sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) - rpmCache := repository.NewRPMCache(redisClient) - groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache) - groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) dataManagementService := service.NewDataManagementService() @@ -175,6 +162,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService) promoHandler := admin.NewPromoHandler(promoService) opsRepository := repository.NewOpsRepository(db) + usageBillingRepository := repository.NewUsageBillingRepository(client, db) pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig) pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) if err != nil { @@ -183,17 +171,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingService := service.NewBillingService(configConfig, pricingService) identityService := service.NewIdentityService(identityCache) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) - claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI) + claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() + channelRepository := repository.NewChannelRepository(db) channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver) - openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) + openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) @@ -221,8 +209,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) channelHandler := admin.NewChannelHandler(channelService, billingService) - adminPaymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, adminPaymentHandler) + registry := payment.ProvideRegistry() + encryptionKey, err := payment.ProvideEncryptionKey(configConfig) + if err != nil { + return nil, err + } + defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) + paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService) + paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) + paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) + paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) @@ -245,7 +243,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go index 8a83bfeb..bf404118 100644 --- a/backend/internal/handler/payment_webhook_handler.go +++ b/backend/internal/handler/payment_webhook_handler.go @@ -4,7 +4,6 @@ import ( "io" "log/slog" "net/http" - "net/url" "strings" "github.com/Wei-Shaw/sub2api/internal/payment" @@ -73,13 +72,9 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) rawBody = string(body) } - // Extract out_trade_no to look up the order's specific provider instance. - // This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts). - outTradeNo := extractOutTradeNo(rawBody, providerKey) - - provider, err := h.paymentService.GetWebhookProvider(c.Request.Context(), providerKey, outTradeNo) + provider, err := h.registry.GetProviderByKey(providerKey) if err != nil { - slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err) + slog.Warn("[Payment Webhook] provider not registered", "provider", providerKey, "error", err) writeSuccessResponse(c, providerKey) return } @@ -116,40 +111,19 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) writeSuccessResponse(c, providerKey) } -// extractOutTradeNo parses the webhook body to find the out_trade_no. -// This allows looking up the correct provider instance before verification. -func extractOutTradeNo(rawBody, providerKey string) string { - switch providerKey { - case payment.TypeEasyPay: - values, err := url.ParseQuery(rawBody) - if err == nil { - return values.Get("out_trade_no") - } - } - // For other providers (Stripe, Alipay direct, WxPay direct), the registry - // typically has only one instance, so no instance lookup is needed. - return "" -} - // wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook. type wxpaySuccessResponse struct { Code string `json:"code"` Message string `json:"message"` } -// WeChat Pay webhook success response constants. -const ( - wxpaySuccessCode = "SUCCESS" - wxpaySuccessMessage = "成功" -) - // writeSuccessResponse sends the provider-specific success response. // WeChat Pay requires JSON {"code":"SUCCESS","message":"成功"}; // Stripe expects an empty 200; others accept plain text "success". func writeSuccessResponse(c *gin.Context, providerKey string) { switch providerKey { case payment.TypeWxpay: - c.JSON(http.StatusOK, wxpaySuccessResponse{Code: wxpaySuccessCode, Message: wxpaySuccessMessage}) + c.JSON(http.StatusOK, wxpaySuccessResponse{Code: "SUCCESS", Message: "成功"}) case payment.TypeStripe: c.String(http.StatusOK, "") default: diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index dafe9afd..9042c3ab 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -2,16 +2,13 @@ package service import ( "context" - "encoding/json" "fmt" "strconv" "strings" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" - "github.com/Wei-Shaw/sub2api/ent/subscriptionplan" "github.com/Wei-Shaw/sub2api/internal/payment" - infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) const ( @@ -26,6 +23,8 @@ const ( SettingBalancePayDisabled = "BALANCE_PAYMENT_DISABLED" SettingProductNamePrefix = "PRODUCT_NAME_PREFIX" SettingProductNameSuffix = "PRODUCT_NAME_SUFFIX" + SettingHelpImageURL = "PAYMENT_HELP_IMAGE_URL" + SettingHelpText = "PAYMENT_HELP_TEXT" SettingCancelRateLimitOn = "CANCEL_RATE_LIMIT_ENABLED" SettingCancelRateLimitMax = "CANCEL_RATE_LIMIT_MAX" SettingCancelWindowSize = "CANCEL_RATE_LIMIT_WINDOW" @@ -33,91 +32,126 @@ const ( SettingCancelWindowMode = "CANCEL_RATE_LIMIT_WINDOW_MODE" ) +// Default values for payment configuration settings. +const ( + defaultOrderTimeoutMin = 30 + defaultMaxPendingOrders = 3 +) + // PaymentConfig holds the payment system configuration. type PaymentConfig struct { - Enabled bool `json:"enabled"` - MinAmount float64 `json:"minAmount"` - MaxAmount float64 `json:"maxAmount"` - DailyLimit float64 `json:"dailyLimit"` - OrderTimeoutMin int `json:"orderTimeoutMinutes"` - MaxPendingOrders int `json:"maxPendingOrders"` - EnabledTypes []string `json:"enabledTypes"` - BalanceDisabled bool `json:"balanceDisabled"` - LoadBalanceStrategy string `json:"loadBalanceStrategy"` - ProductNamePrefix string `json:"productNamePrefix"` - ProductNameSuffix string `json:"productNameSuffix"` + Enabled bool `json:"enabled"` + MinAmount float64 `json:"min_amount"` + MaxAmount float64 `json:"max_amount"` + DailyLimit float64 `json:"daily_limit"` + OrderTimeoutMin int `json:"order_timeout_minutes"` + MaxPendingOrders int `json:"max_pending_orders"` + EnabledTypes []string `json:"enabled_payment_types"` + BalanceDisabled bool `json:"balance_disabled"` + LoadBalanceStrategy string `json:"load_balance_strategy"` + ProductNamePrefix string `json:"product_name_prefix"` + ProductNameSuffix string `json:"product_name_suffix"` + HelpImageURL string `json:"help_image_url"` + HelpText string `json:"help_text"` + StripePublishableKey string `json:"stripe_publishable_key,omitempty"` + + // Cancel rate limit settings + CancelRateLimitEnabled bool `json:"cancel_rate_limit_enabled"` + CancelRateLimitMax int `json:"cancel_rate_limit_max"` + CancelRateLimitWindow int `json:"cancel_rate_limit_window"` + CancelRateLimitUnit string `json:"cancel_rate_limit_unit"` + CancelRateLimitMode string `json:"cancel_rate_limit_window_mode"` } // UpdatePaymentConfigRequest contains fields to update payment configuration. type UpdatePaymentConfigRequest struct { Enabled *bool `json:"enabled"` - MinAmount *float64 `json:"minAmount"` - MaxAmount *float64 `json:"maxAmount"` - DailyLimit *float64 `json:"dailyLimit"` - OrderTimeoutMin *int `json:"orderTimeoutMinutes"` - MaxPendingOrders *int `json:"maxPendingOrders"` - EnabledTypes []string `json:"enabledTypes"` - BalanceDisabled *bool `json:"balanceDisabled"` - LoadBalanceStrategy *string `json:"loadBalanceStrategy"` - ProductNamePrefix *string `json:"productNamePrefix"` - ProductNameSuffix *string `json:"productNameSuffix"` + MinAmount *float64 `json:"min_amount"` + MaxAmount *float64 `json:"max_amount"` + DailyLimit *float64 `json:"daily_limit"` + OrderTimeoutMin *int `json:"order_timeout_minutes"` + MaxPendingOrders *int `json:"max_pending_orders"` + EnabledTypes []string `json:"enabled_payment_types"` + BalanceDisabled *bool `json:"balance_disabled"` + LoadBalanceStrategy *string `json:"load_balance_strategy"` + ProductNamePrefix *string `json:"product_name_prefix"` + ProductNameSuffix *string `json:"product_name_suffix"` + HelpImageURL *string `json:"help_image_url"` + HelpText *string `json:"help_text"` + + // Cancel rate limit settings + CancelRateLimitEnabled *bool `json:"cancel_rate_limit_enabled"` + CancelRateLimitMax *int `json:"cancel_rate_limit_max"` + CancelRateLimitWindow *int `json:"cancel_rate_limit_window"` + CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"` + CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"` } // MethodLimits holds per-payment-type limits. type MethodLimits struct { - PaymentType string `json:"paymentType"` - FeeRate float64 `json:"feeRate"` - DailyLimit float64 `json:"dailyLimit"` - SingleMin float64 `json:"singleMin"` - SingleMax float64 `json:"singleMax"` + PaymentType string `json:"payment_type"` + FeeRate float64 `json:"fee_rate"` + DailyLimit float64 `json:"daily_limit"` + SingleMin float64 `json:"single_min"` + SingleMax float64 `json:"single_max"` +} + +// MethodLimitsResponse is the full response for the user-facing /limits API. +// It includes per-method limits and the global widest range (union of all methods). +type MethodLimitsResponse struct { + Methods map[string]MethodLimits `json:"methods"` + GlobalMin float64 `json:"global_min"` // 0 = no minimum + GlobalMax float64 `json:"global_max"` // 0 = no maximum } type CreateProviderInstanceRequest struct { - ProviderKey string `json:"providerKey"` + ProviderKey string `json:"provider_key"` Name string `json:"name"` Config map[string]string `json:"config"` - SupportedTypes string `json:"supportedTypes"` + SupportedTypes []string `json:"supported_types"` Enabled bool `json:"enabled"` - SortOrder int `json:"sortOrder"` + PaymentMode string `json:"payment_mode"` + SortOrder int `json:"sort_order"` Limits string `json:"limits"` - RefundEnabled bool `json:"refundEnabled"` + RefundEnabled bool `json:"refund_enabled"` } type UpdateProviderInstanceRequest struct { Name *string `json:"name"` Config map[string]string `json:"config"` - SupportedTypes *string `json:"supportedTypes"` + SupportedTypes []string `json:"supported_types"` Enabled *bool `json:"enabled"` - SortOrder *int `json:"sortOrder"` + PaymentMode *string `json:"payment_mode"` + SortOrder *int `json:"sort_order"` Limits *string `json:"limits"` - RefundEnabled *bool `json:"refundEnabled"` + RefundEnabled *bool `json:"refund_enabled"` } type CreatePlanRequest struct { - GroupID int64 `json:"groupId"` + GroupID int64 `json:"group_id"` Name string `json:"name"` Description string `json:"description"` Price float64 `json:"price"` - OriginalPrice *float64 `json:"originalPrice"` - ValidityDays int `json:"validityDays"` - ValidityUnit string `json:"validityUnit"` + OriginalPrice *float64 `json:"original_price"` + ValidityDays int `json:"validity_days"` + ValidityUnit string `json:"validity_unit"` Features string `json:"features"` - ProductName string `json:"productName"` - ForSale bool `json:"forSale"` - SortOrder int `json:"sortOrder"` + ProductName string `json:"product_name"` + ForSale bool `json:"for_sale"` + SortOrder int `json:"sort_order"` } type UpdatePlanRequest struct { - GroupID *int64 `json:"groupId"` + GroupID *int64 `json:"group_id"` Name *string `json:"name"` Description *string `json:"description"` Price *float64 `json:"price"` - OriginalPrice *float64 `json:"originalPrice"` - ValidityDays *int `json:"validityDays"` - ValidityUnit *string `json:"validityUnit"` + OriginalPrice *float64 `json:"original_price"` + ValidityDays *int `json:"validity_days"` + ValidityUnit *string `json:"validity_unit"` Features *string `json:"features"` - ProductName *string `json:"productName"` - ForSale *bool `json:"forSale"` - SortOrder *int `json:"sortOrder"` + ProductName *string `json:"product_name"` + ForSale *bool `json:"for_sale"` + SortOrder *int `json:"sort_order"` } // PaymentConfigService manages payment configuration and CRUD for @@ -149,29 +183,43 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo SettingDailyRechargeLimit, SettingOrderTimeoutMinutes, SettingMaxPendingOrders, SettingEnabledPaymentTypes, SettingBalancePayDisabled, SettingLoadBalanceStrategy, SettingProductNamePrefix, SettingProductNameSuffix, + SettingHelpImageURL, SettingHelpText, + SettingCancelRateLimitOn, SettingCancelRateLimitMax, + SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode, } vals, err := s.settingRepo.GetMultiple(ctx, keys) if err != nil { return nil, fmt.Errorf("get payment config settings: %w", err) } - return s.parsePaymentConfig(vals), nil + cfg := s.parsePaymentConfig(vals) + // Load Stripe publishable key from the first enabled Stripe provider instance + cfg.StripePublishableKey = s.getStripePublishableKey(ctx) + return cfg, nil } func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *PaymentConfig { cfg := &PaymentConfig{ Enabled: vals[SettingPaymentEnabled] == "true", MinAmount: pcParseFloat(vals[SettingMinRechargeAmount], 1), - MaxAmount: pcParseFloat(vals[SettingMaxRechargeAmount], 99999999.99), + MaxAmount: pcParseFloat(vals[SettingMaxRechargeAmount], 0), DailyLimit: pcParseFloat(vals[SettingDailyRechargeLimit], 0), - OrderTimeoutMin: pcParseInt(vals[SettingOrderTimeoutMinutes], 30), - MaxPendingOrders: pcParseInt(vals[SettingMaxPendingOrders], 3), + OrderTimeoutMin: pcParseInt(vals[SettingOrderTimeoutMinutes], defaultOrderTimeoutMin), + MaxPendingOrders: pcParseInt(vals[SettingMaxPendingOrders], defaultMaxPendingOrders), BalanceDisabled: vals[SettingBalancePayDisabled] == "true", LoadBalanceStrategy: vals[SettingLoadBalanceStrategy], ProductNamePrefix: vals[SettingProductNamePrefix], ProductNameSuffix: vals[SettingProductNameSuffix], + HelpImageURL: vals[SettingHelpImageURL], + HelpText: vals[SettingHelpText], + + CancelRateLimitEnabled: vals[SettingCancelRateLimitOn] == "true", + CancelRateLimitMax: pcParseInt(vals[SettingCancelRateLimitMax], 10), + CancelRateLimitWindow: pcParseInt(vals[SettingCancelWindowSize], 1), + CancelRateLimitUnit: vals[SettingCancelWindowUnit], + CancelRateLimitMode: vals[SettingCancelWindowMode], } if cfg.LoadBalanceStrategy == "" { - cfg.LoadBalanceStrategy = "round-robin" + cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy } if raw := vals[SettingEnabledPaymentTypes]; raw != "" { for _, t := range strings.Split(raw, ",") { @@ -184,242 +232,100 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme return cfg } +// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance. +func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string { + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.EnabledEQ(true), + paymentproviderinstance.ProviderKeyEQ(payment.TypeStripe), + ).Limit(1).All(ctx) + if err != nil || len(instances) == 0 { + return "" + } + cfg, err := s.decryptConfig(instances[0].Config) + if err != nil || cfg == nil { + return "" + } + return cfg[payment.ConfigKeyPublishableKey] +} + // UpdatePaymentConfig updates the payment configuration settings. +// NOTE: This function exceeds 30 lines because each field requires an independent +// nil-check before serialisation — this is inherent to patch-style update patterns +// and cannot be meaningfully decomposed without introducing unnecessary abstraction. func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req UpdatePaymentConfigRequest) error { - m := make(map[string]string) - if req.Enabled != nil { - m[SettingPaymentEnabled] = strconv.FormatBool(*req.Enabled) - } - if req.MinAmount != nil { - m[SettingMinRechargeAmount] = strconv.FormatFloat(*req.MinAmount, 'f', 2, 64) - } - if req.MaxAmount != nil { - m[SettingMaxRechargeAmount] = strconv.FormatFloat(*req.MaxAmount, 'f', 2, 64) - } - if req.DailyLimit != nil { - m[SettingDailyRechargeLimit] = strconv.FormatFloat(*req.DailyLimit, 'f', 2, 64) - } - if req.OrderTimeoutMin != nil { - m[SettingOrderTimeoutMinutes] = strconv.Itoa(*req.OrderTimeoutMin) - } - if req.MaxPendingOrders != nil { - m[SettingMaxPendingOrders] = strconv.Itoa(*req.MaxPendingOrders) + m := map[string]string{ + SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled), + SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount), + SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount), + SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit), + SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin), + SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders), + SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled), + SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy), + SettingProductNamePrefix: derefStr(req.ProductNamePrefix), + SettingProductNameSuffix: derefStr(req.ProductNameSuffix), + SettingHelpImageURL: derefStr(req.HelpImageURL), + SettingHelpText: derefStr(req.HelpText), + SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled), + SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax), + SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow), + SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit), + SettingCancelWindowMode: derefStr(req.CancelRateLimitMode), } if req.EnabledTypes != nil { m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",") - } - if req.BalanceDisabled != nil { - m[SettingBalancePayDisabled] = strconv.FormatBool(*req.BalanceDisabled) - } - if req.LoadBalanceStrategy != nil { - m[SettingLoadBalanceStrategy] = *req.LoadBalanceStrategy - } - if req.ProductNamePrefix != nil { - m[SettingProductNamePrefix] = *req.ProductNamePrefix - } - if req.ProductNameSuffix != nil { - m[SettingProductNameSuffix] = *req.ProductNameSuffix - } - if len(m) == 0 { - return nil + } else { + m[SettingEnabledPaymentTypes] = "" } return s.settingRepo.SetMultiple(ctx, m) } -// --- Provider Instance CRUD --- - -func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*dbent.PaymentProviderInstance, error) { - return s.entClient.PaymentProviderInstance.Query().Order(paymentproviderinstance.BySortOrder()).All(ctx) +func formatBoolOrEmpty(v *bool) string { + if v == nil { + return "" + } + return strconv.FormatBool(*v) } -func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req CreateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) { - enc, err := s.encryptConfig(req.Config) - if err != nil { - return nil, err +func formatPositiveFloat(v *float64) string { + if v == nil || *v <= 0 { + return "" // empty → parsePaymentConfig uses default } - return s.entClient.PaymentProviderInstance.Create(). - SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc). - SetSupportedTypes(req.SupportedTypes).SetEnabled(req.Enabled). - SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled). - Save(ctx) + return strconv.FormatFloat(*v, 'f', 2, 64) } -func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) { - u := s.entClient.PaymentProviderInstance.UpdateOneID(id) - if req.Name != nil { - u.SetName(*req.Name) +func formatPositiveInt(v *int) string { + if v == nil || *v <= 0 { + return "" } - if req.Config != nil { - enc, err := s.encryptConfig(req.Config) - if err != nil { - return nil, err - } - u.SetConfig(enc) - } - if req.SupportedTypes != nil { - u.SetSupportedTypes(*req.SupportedTypes) - } - if req.Enabled != nil { - u.SetEnabled(*req.Enabled) - } - if req.SortOrder != nil { - u.SetSortOrder(*req.SortOrder) - } - if req.Limits != nil { - u.SetLimits(*req.Limits) - } - if req.RefundEnabled != nil { - u.SetRefundEnabled(*req.RefundEnabled) - } - return u.Save(ctx) + return strconv.Itoa(*v) } -func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error { - return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx) +func derefStr(v *string) string { + if v == nil { + return "" + } + return *v } -func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) { - data, err := json.Marshal(cfg) - if err != nil { - return "", fmt.Errorf("marshal config: %w", err) +func splitTypes(s string) []string { + if s == "" { + return nil } - enc, err := payment.Encrypt(string(data), s.encryptionKey) - if err != nil { - return "", fmt.Errorf("encrypt config: %w", err) - } - return enc, nil -} - -// --- Channel CRUD --- - - -// --- Plan CRUD --- - -func (s *PaymentConfigService) ListPlans(ctx context.Context) ([]*dbent.SubscriptionPlan, error) { - return s.entClient.SubscriptionPlan.Query().Order(subscriptionplan.BySortOrder()).All(ctx) -} - -func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.SubscriptionPlan, error) { - return s.entClient.SubscriptionPlan.Query().Where(subscriptionplan.ForSaleEQ(true)).Order(subscriptionplan.BySortOrder()).All(ctx) -} - -func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) { - b := s.entClient.SubscriptionPlan.Create(). - SetGroupID(req.GroupID).SetName(req.Name).SetDescription(req.Description). - SetPrice(req.Price).SetValidityDays(req.ValidityDays).SetValidityUnit(req.ValidityUnit). - SetFeatures(req.Features).SetProductName(req.ProductName). - SetForSale(req.ForSale).SetSortOrder(req.SortOrder) - if req.OriginalPrice != nil { - b.SetOriginalPrice(*req.OriginalPrice) - } - return b.Save(ctx) -} - -func (s *PaymentConfigService) UpdatePlan(ctx context.Context, id int64, req UpdatePlanRequest) (*dbent.SubscriptionPlan, error) { - u := s.entClient.SubscriptionPlan.UpdateOneID(id) - if req.GroupID != nil { - u.SetGroupID(*req.GroupID) - } - if req.Name != nil { - u.SetName(*req.Name) - } - if req.Description != nil { - u.SetDescription(*req.Description) - } - if req.Price != nil { - u.SetPrice(*req.Price) - } - if req.OriginalPrice != nil { - u.SetOriginalPrice(*req.OriginalPrice) - } - if req.ValidityDays != nil { - u.SetValidityDays(*req.ValidityDays) - } - if req.ValidityUnit != nil { - u.SetValidityUnit(*req.ValidityUnit) - } - if req.Features != nil { - u.SetFeatures(*req.Features) - } - if req.ProductName != nil { - u.SetProductName(*req.ProductName) - } - if req.ForSale != nil { - u.SetForSale(*req.ForSale) - } - if req.SortOrder != nil { - u.SetSortOrder(*req.SortOrder) - } - return u.Save(ctx) -} - -func (s *PaymentConfigService) DeletePlan(ctx context.Context, id int64) error { - return s.entClient.SubscriptionPlan.DeleteOneID(id).Exec(ctx) -} - -// GetPlan returns a subscription plan by ID. -func (s *PaymentConfigService) GetPlan(ctx context.Context, id int64) (*dbent.SubscriptionPlan, error) { - plan, err := s.entClient.SubscriptionPlan.Get(ctx, id) - if err != nil { - return nil, infraerrors.NotFound("PLAN_NOT_FOUND", "subscription plan not found") - } - return plan, nil -} - -// GetMethodLimits returns per-payment-type limits from enabled provider instances. -func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) { - instances, err := s.entClient.PaymentProviderInstance.Query(). - Where(paymentproviderinstance.EnabledEQ(true)).All(ctx) - if err != nil { - return nil, fmt.Errorf("query provider instances: %w", err) - } - result := make([]MethodLimits, 0, len(types)) - for _, pt := range types { - ml := MethodLimits{PaymentType: pt} - for _, inst := range instances { - if !pcInstanceSupportsType(inst, pt) { - continue - } - pcApplyInstanceLimits(inst, pt, &ml) - } - result = append(result, ml) - } - return result, nil -} - -func pcInstanceSupportsType(inst *dbent.PaymentProviderInstance, pt string) bool { - if inst.SupportedTypes == "" { - return true - } - for _, t := range strings.Split(inst.SupportedTypes, ",") { - if strings.TrimSpace(t) == pt { - return true + parts := strings.Split(s, ",") + result := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + result = append(result, p) } } - return false + return result } -func pcApplyInstanceLimits(inst *dbent.PaymentProviderInstance, pt string, ml *MethodLimits) { - if inst.Limits == "" { - return - } - var limits payment.InstanceLimits - if err := json.Unmarshal([]byte(inst.Limits), &limits); err != nil { - return - } - cl, ok := limits[pt] - if !ok { - return - } - if cl.DailyLimit > 0 && (ml.DailyLimit == 0 || cl.DailyLimit < ml.DailyLimit) { - ml.DailyLimit = cl.DailyLimit - } - if cl.SingleMin > 0 && (ml.SingleMin == 0 || cl.SingleMin > ml.SingleMin) { - ml.SingleMin = cl.SingleMin - } - if cl.SingleMax > 0 && (ml.SingleMax == 0 || cl.SingleMax < ml.SingleMax) { - ml.SingleMax = cl.SingleMax - } +func joinTypes(types []string) string { + return strings.Join(types, ",") } func pcParseFloat(s string, defaultVal float64) float64 { diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index de41d742..db92ff2b 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -5,12 +5,9 @@ import ( "fmt" "log/slog" "math" - "strconv" - "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/internal/payment" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -19,20 +16,14 @@ import ( // --- Payment Notification & Fulfillment --- func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payment.PaymentNotification, pk string) error { - if n.Status != payment.NotificationStatusSuccess { + if n.Status != "success" { return nil } - // Look up order by out_trade_no (the external order ID we sent to the provider) - order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx) + oid, err := parseOrderID(n.OrderID) if err != nil { - // Fallback: try legacy format (sub2_N where N is DB ID) - trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix) - if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil { - return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk) - } - return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID) + return fmt.Errorf("invalid order ID: %s", n.OrderID) } - return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk) + return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk) } func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error { @@ -41,17 +32,9 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo slog.Error("order not found", "orderID", oid) return nil } - // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount). - // Also skip if paid is NaN/Inf (malformed provider data). - if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) { - if math.Abs(paid-o.PayAmount) > amountToleranceCNY { - s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo}) - return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid) - } - } - // Use order's expected amount when provider didn't report one - if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) { - paid = o.PayAmount + if math.Abs(paid-o.PayAmount) > amountToleranceCNY { + s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo}) + return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid) } return s.toPaid(ctx, o, tradeNo, paid, pk) } @@ -129,7 +112,7 @@ func (s *PaymentService) executeFulfillment(ctx context.Context, oid int64) erro if err != nil { return fmt.Errorf("get order: %w", err) } - if o.OrderType == payment.OrderTypeSubscription { + if o.OrderType == "subscription" { return s.ExecuteSubscriptionFulfillment(ctx, oid) } return s.ExecuteBalanceFulfillment(ctx, oid) @@ -163,46 +146,20 @@ func (s *PaymentService) ExecuteBalanceFulfillment(ctx context.Context, oid int6 return nil } -// redeemAction represents the idempotency decision for balance fulfillment. -type redeemAction int - -const ( - // redeemActionCreate: code does not exist — create it, then redeem. - redeemActionCreate redeemAction = iota - // redeemActionRedeem: code exists but is unused — skip creation, redeem only. - redeemActionRedeem - // redeemActionSkipCompleted: code exists and is already used — skip to mark completed. - redeemActionSkipCompleted -) - -// resolveRedeemAction decides the idempotency action based on an existing redeem code lookup. -// existing is the result of GetByCode; lookupErr is the error from that call. -func resolveRedeemAction(existing *RedeemCode, lookupErr error) redeemAction { - if existing == nil || lookupErr != nil { - return redeemActionCreate - } - if existing.IsUsed() { - return redeemActionSkipCompleted - } - return redeemActionRedeem -} - func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) error { // Idempotency: check if redeem code already exists (from a previous partial run) - existing, lookupErr := s.redeemService.GetByCode(ctx, o.RechargeCode) - action := resolveRedeemAction(existing, lookupErr) - - switch action { - case redeemActionSkipCompleted: - // Code already created and redeemed — just mark completed - return s.markCompleted(ctx, o, "RECHARGE_SUCCESS") - case redeemActionCreate: + existing, _ := s.redeemService.GetByCode(ctx, o.RechargeCode) + if existing != nil { + if existing.IsUsed() { + // Code already created and redeemed — just mark completed + return s.markCompleted(ctx, o, "RECHARGE_SUCCESS") + } + // Code exists but unused — skip creation, proceed to redeem + } else { rc := &RedeemCode{Code: o.RechargeCode, Type: RedeemTypeBalance, Value: o.Amount, Status: StatusUnused} if err := s.redeemService.CreateCode(ctx, rc); err != nil { return fmt.Errorf("create redeem code: %w", err) } - case redeemActionRedeem: - // Code exists but unused — skip creation, proceed to redeem } if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil { return fmt.Errorf("redeem balance: %w", err) @@ -255,45 +212,30 @@ func (s *PaymentService) doSub(ctx context.Context, o *dbent.PaymentOrder) error gid := *o.SubscriptionGroupID days := *o.SubscriptionDays g, err := s.groupRepo.GetByID(ctx, gid) - if err != nil || g.Status != payment.EntityStatusActive { + if err != nil || g.Status != "active" { return fmt.Errorf("group %d no longer exists or inactive", gid) } - // Idempotency: check audit log to see if subscription was already assigned. - // Prevents double-extension on retry after markCompleted fails. - if s.hasAuditLog(ctx, o.ID, "SUBSCRIPTION_SUCCESS") { - slog.Info("subscription already assigned for order, skipping", "orderID", o.ID, "groupID", gid) - return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS") - } - orderNote := fmt.Sprintf("payment order %d", o.ID) - _, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: orderNote}) + _, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: fmt.Sprintf("payment order %d", o.ID)}) if err != nil { return fmt.Errorf("assign subscription: %w", err) } - return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS") -} - -func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action string) bool { - oid := strconv.FormatInt(orderID, 10) - c, _ := s.entClient.PaymentAuditLog.Query(). - Where(paymentauditlog.OrderIDEQ(oid), paymentauditlog.ActionEQ(action)). - Limit(1).Count(ctx) - return c > 0 + now := time.Now() + _, err = s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusRecharging)).SetStatus(OrderStatusCompleted).SetCompletedAt(now).Save(ctx) + if err != nil { + return fmt.Errorf("mark completed: %w", err) + } + s.writeAuditLog(ctx, o.ID, "SUBSCRIPTION_SUCCESS", "system", map[string]any{"groupId": gid, "days": days, "amount": o.Amount}) + return nil } func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) { now := time.Now() r := psErrMsg(cause) - // Only mark FAILED if still in RECHARGING state — prevents overwriting - // a COMPLETED order when markCompleted failed but fulfillment succeeded. - c, e := s.entClient.PaymentOrder.Update(). - Where(paymentorder.IDEQ(oid), paymentorder.StatusEQ(OrderStatusRecharging)). - SetStatus(OrderStatusFailed).SetFailedAt(now).SetFailedReason(r).Save(ctx) + _, e := s.entClient.PaymentOrder.UpdateOneID(oid).SetStatus(OrderStatusFailed).SetFailedAt(now).SetFailedReason(r).Save(ctx) if e != nil { slog.Error("mark FAILED", "orderID", oid, "error", e) } - if c > 0 { - s.writeAuditLog(ctx, oid, "FULFILLMENT_FAILED", "system", map[string]any{"reason": r}) - } + s.writeAuditLog(ctx, oid, "FULFILLMENT_FAILED", "system", map[string]any{"reason": r}) } func (s *PaymentService) RetryFulfillment(ctx context.Context, oid int64) error { diff --git a/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue b/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue index 7320037d..06bc9218 100644 --- a/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue +++ b/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue @@ -42,7 +42,7 @@ {{ t('payment.methods.' + method.type, method.type) }}
- ${{ method.amount.toFixed(2) }} + ¥{{ method.amount.toFixed(2) }} ({{ method.count }})
@@ -57,7 +57,7 @@ {{ idx + 1 }} {{ user.email }} - ${{ user.amount.toFixed(2) }} + ¥{{ user.amount.toFixed(2) }} diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index 5a958097..5dc396ec 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -5,217 +5,82 @@

+ +
+ + +
+ $ + +
+
@@ -289,6 +368,36 @@ const onWeeklyModeChange = (e: Event) => { />

{{ t('admin.accounts.quotaTotalLimitHint') }}

+ +
+ + +
+ $ + +
+
diff --git a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue new file mode 100644 index 00000000..130d82b5 --- /dev/null +++ b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue @@ -0,0 +1,204 @@ + + + diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index dd45ea17..7119fa36 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -902,6 +902,31 @@ export default { sendCode: 'Send Code', codeSent: 'Verification code sent to your email', sendCodeFailed: 'Failed to send verification code' + }, + balanceNotify: { + title: 'Balance Low Notification', + description: 'Send email alert when account balance falls below threshold', + enabled: 'Enable Balance Low Notification', + threshold: 'Custom Threshold', + thresholdHint: 'Leave empty to use system default', + thresholdPlaceholder: 'Enter amount', + systemDefault: 'System Default', + extraEmails: 'Extra Notification Emails', + noExtraEmails: 'No extra notification emails', + enterEmail: 'Enter email address', + addEmail: 'Add Email', + emailPlaceholder: 'Enter email address', + sendCode: 'Send Code', + codeSent: 'Verification code sent', + codeSentTo: 'Code sent to {email}', + enterCode: 'Enter verification code', + codePlaceholder: '6-digit code', + verify: 'Verify & Add', + emailAdded: 'Email added', + emailRemoved: 'Email removed', + verifySuccess: 'Email added successfully', + removeEmail: 'Remove', + removeSuccess: 'Email removed', } }, @@ -2228,6 +2253,12 @@ export default { }, quotaLimitAmount: 'Total Limit', quotaLimitAmountHint: 'Cumulative spending limit. Does not auto-reset.', + quotaNotify: { + alert: 'Alert Threshold', + enabled: 'Enable Alert', + threshold: 'Alert Amount', + thresholdPlaceholder: 'Enter alert amount', + }, testConnection: 'Test Connection', reAuthorize: 'Re-Authorize', refreshToken: 'Refresh Token', @@ -4593,6 +4624,22 @@ export default { supportedTypesHint: 'Comma-separated, e.g. alipay,wxpay', refundEnabled: 'Allow Refund', }, + balanceNotify: { + title: 'Balance Low Notification', + description: 'Send email notification when user balance falls below threshold', + enabled: 'Enable Balance Low Notification', + threshold: 'Default Threshold', + thresholdHint: 'Used when user has not set a custom value', + thresholdPlaceholder: 'Enter amount', + }, + quotaNotify: { + title: 'Account Quota Notification', + description: 'Notify admins when account quota usage reaches alert threshold', + emails: 'Notification Emails', + emailsHint: 'Leave empty to disable notifications', + addEmail: 'Add Email', + emailPlaceholder: 'Enter email address', + }, smtp: { title: 'SMTP Settings', description: 'Configure email sending for verification codes', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index bbfc7971..6efaf657 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -906,6 +906,31 @@ export default { sendCode: '发送验证码', codeSent: '验证码已发送到您的邮箱', sendCodeFailed: '发送验证码失败' + }, + balanceNotify: { + title: '余额不足提醒', + description: '当账户余额低于阈值时发送邮件提醒', + enabled: '启用余额不足提醒', + threshold: '自定义提醒阈值', + thresholdHint: '留空使用系统默认值', + thresholdPlaceholder: '输入金额', + systemDefault: '系统默认值', + extraEmails: '额外通知邮箱', + noExtraEmails: '暂无额外通知邮箱', + enterEmail: '输入邮箱地址', + addEmail: '添加邮箱', + emailPlaceholder: '输入邮箱地址', + sendCode: '发送验证码', + codeSent: '验证码已发送', + codeSentTo: '验证码已发送到 {email}', + enterCode: '输入验证码', + codePlaceholder: '6位验证码', + verify: '确认添加', + emailAdded: '邮箱已添加', + emailRemoved: '邮箱已移除', + verifySuccess: '邮箱添加成功', + removeEmail: '移除', + removeSuccess: '邮箱已移除', } }, @@ -2226,6 +2251,12 @@ export default { }, quotaLimitAmount: '总限额', quotaLimitAmountHint: '累计消费上限,不会自动重置。', + quotaNotify: { + alert: '告警阈值', + enabled: '启用告警', + threshold: '告警金额', + thresholdPlaceholder: '输入告警金额', + }, testConnection: '测试连接', reAuthorize: '重新授权', refreshToken: '刷新令牌', @@ -4757,6 +4788,22 @@ export default { supportedTypesHint: '逗号分隔,如 alipay,wxpay', refundEnabled: '允许退款', }, + balanceNotify: { + title: '余额不足提醒', + description: '当用户余额低于阈值时发送邮件提醒', + enabled: '启用余额不足提醒', + threshold: '默认提醒阈值', + thresholdHint: '用户未自定义时使用此值', + thresholdPlaceholder: '输入金额', + }, + quotaNotify: { + title: '账号限额通知', + description: '当账号配额用量达到告警阈值时通知管理员', + emails: '通知邮箱', + emailsHint: '留空则不发送通知', + addEmail: '添加邮箱', + emailPlaceholder: '输入邮箱地址', + }, smtp: { title: 'SMTP 设置', description: '配置用于发送验证码的邮件服务', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 8b7e44c1..e74f6e61 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -33,6 +33,9 @@ export interface User { concurrency: number // Allowed concurrent requests status: 'active' | 'disabled' // Account status allowed_groups: number[] | null // Allowed group IDs (null = all non-exclusive groups) + balance_notify_enabled: boolean + balance_notify_threshold: number | null + balance_notify_extra_emails: string[] subscriptions?: UserSubscription[] // User's active subscriptions created_at: string updated_at: string diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 6abc725a..f57bfcf3 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -2562,6 +2562,60 @@ + +
+
+

+ {{ t('admin.settings.balanceNotify.title') }} +

+

+ {{ t('admin.settings.balanceNotify.description') }} +

+
+
+
+ + +
+
+ +
+ $ + +
+

{{ t('admin.settings.balanceNotify.thresholdHint') }}

+
+
+
+ + +
+
+

+ {{ t('admin.settings.quotaNotify.title') }} +

+

+ {{ t('admin.settings.quotaNotify.description') }} +

+
+
+
+ +
+
+ + +
+ +
+

{{ t('admin.settings.quotaNotify.emailsHint') }}

+
+
+
@@ -2840,7 +2894,11 @@ const form = reactive({ // Gateway forwarding behavior enable_fingerprint_unification: true, enable_metadata_passthrough: false, - enable_cch_signing: false + enable_cch_signing: false, + // Balance & quota notification + balance_low_notify_enabled: false, + balance_low_notify_threshold: 0, + account_quota_notify_emails: [] as string[] }) // Web Search Emulation config (loaded/saved separately) @@ -2972,6 +3030,14 @@ function handleRegistrationEmailSuffixWhitelistPaste(event: ClipboardEvent) { } } +// Quota notify email helpers +const addQuotaNotifyEmail = () => { + if (!form.account_quota_notify_emails) { + form.account_quota_notify_emails = [] + } + form.account_quota_notify_emails.push('') +} + // LinuxDo OAuth redirect URL suggestion const linuxdoRedirectUrlSuggestion = computed(() => { if (typeof window === 'undefined') return '' @@ -3311,6 +3377,10 @@ async function saveSettings() { payment_cancel_rate_limit_window: Number(form.payment_cancel_rate_limit_window) || 1, payment_cancel_rate_limit_unit: form.payment_cancel_rate_limit_unit, payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode, + // Balance & quota notification + balance_low_notify_enabled: form.balance_low_notify_enabled, + balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0, + account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e: string) => e.trim() !== ''), } const updated = await adminAPI.settings.updateSettings(payload) diff --git a/frontend/src/views/user/ProfileView.vue b/frontend/src/views/user/ProfileView.vue index 0967e2b9..5534e1d6 100644 --- a/frontend/src/views/user/ProfileView.vue +++ b/frontend/src/views/user/ProfileView.vue @@ -14,6 +14,12 @@ + @@ -27,6 +33,7 @@ import { authAPI } from '@/api'; import AppLayout from '@/components/layout/AppL import StatCard from '@/components/common/StatCard.vue' import ProfileInfoCard from '@/components/user/profile/ProfileInfoCard.vue' import ProfileEditForm from '@/components/user/profile/ProfileEditForm.vue' +import ProfileBalanceNotifyCard from '@/components/user/profile/ProfileBalanceNotifyCard.vue' import ProfilePasswordForm from '@/components/user/profile/ProfilePasswordForm.vue' import ProfileTotpCard from '@/components/user/profile/ProfileTotpCard.vue' import { Icon } from '@/components/icons' From c3812ce1e3fd845fe23a4cbc63f09655fd15fcab Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 12 Apr 2026 12:48:17 +0800 Subject: [PATCH 21/88] fix(notify): address review findings - accountCost formula, dedup, refactor - Fix accountCost calculation in finalizePostUsageBilling to match postUsageBilling (always multiply by AccountRateMultiplier) - Use strings.EqualFold for email dedup in collectBalanceNotifyRecipients - Extract CheckAccountQuotaAfterIncrement into smaller functions: buildQuotaDims + asyncSendQuotaAlert (< 30 lines each) - Add "not splittable" comments for HTML template functions - Extract QuotaNotifyToggle.vue sub-component to reduce QuotaLimitCard.vue from 404 to 339 lines --- .../service/balance_notify_service.go | 82 ++++++------- backend/internal/service/gateway_service.go | 7 +- .../src/components/account/QuotaLimitCard.vue | 109 ++++-------------- .../components/account/QuotaNotifyToggle.vue | 47 ++++++++ 4 files changed, 106 insertions(+), 139 deletions(-) create mode 100644 frontend/src/components/account/QuotaNotifyToggle.vue diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 7cd61a0a..8223a231 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -85,73 +85,59 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u } } +// quotaDim describes one quota dimension for notification checking. +type quotaDim struct { + name string + enabled bool + threshold float64 + oldUsed float64 + limit float64 +} + +// buildQuotaDims returns the three quota dimensions for notification checking. +func buildQuotaDims(account *Account) []quotaDim { + return []quotaDim{ + {quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaDailyUsed(), account.GetQuotaDailyLimit()}, + {quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaWeeklyUsed(), account.GetQuotaWeeklyLimit()}, + {quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaUsed(), account.GetQuotaLimit()}, + } +} + // CheckAccountQuotaAfterIncrement checks if any quota dimension crossed above its notify threshold. // The account's Extra fields contain pre-increment usage values. func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Context, account *Account, cost float64) { if account == nil || s.emailService == nil || s.settingRepo == nil || cost <= 0 { return } - adminEmails := s.getAccountQuotaNotifyEmails(ctx) if len(adminEmails) == 0 { return } siteName := s.getSiteName(ctx) - - // Check each dimension - type quotaDim struct { - name string - enabled bool - threshold float64 - oldUsed float64 - limit float64 - } - - dims := []quotaDim{ - { - name: quotaDimDaily, - enabled: account.GetQuotaNotifyDailyEnabled(), - threshold: account.GetQuotaNotifyDailyThreshold(), - oldUsed: account.GetQuotaDailyUsed(), - limit: account.GetQuotaDailyLimit(), - }, - { - name: quotaDimWeekly, - enabled: account.GetQuotaNotifyWeeklyEnabled(), - threshold: account.GetQuotaNotifyWeeklyThreshold(), - oldUsed: account.GetQuotaWeeklyUsed(), - limit: account.GetQuotaWeeklyLimit(), - }, - { - name: quotaDimTotal, - enabled: account.GetQuotaNotifyTotalEnabled(), - threshold: account.GetQuotaNotifyTotalThreshold(), - oldUsed: account.GetQuotaUsed(), - limit: account.GetQuotaLimit(), - }, - } - - for _, dim := range dims { + for _, dim := range buildQuotaDims(account) { if !dim.enabled || dim.threshold <= 0 { continue } newUsed := dim.oldUsed + cost - // Only notify on first crossing if dim.oldUsed < dim.threshold && newUsed >= dim.threshold { - dimCopy := dim // capture loop variable - go func() { - defer func() { - if r := recover(); r != nil { - slog.Error("panic in quota notification", "recover", r) - } - }() - s.sendQuotaAlertEmails(adminEmails, account.Name, dimCopy.name, newUsed, dimCopy.limit, dimCopy.threshold, siteName) - }() + s.asyncSendQuotaAlert(adminEmails, account.Name, dim, newUsed, siteName) } } } +// asyncSendQuotaAlert sends quota alert email in a goroutine with panic recovery. +func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, accountName string, dim quotaDim, newUsed float64, siteName string) { + go func() { + defer func() { + if r := recover(); r != nil { + slog.Error("panic in quota notification", "recover", r) + } + }() + s.sendQuotaAlertEmails(adminEmails, accountName, dim.name, newUsed, dim.limit, dim.threshold, siteName) + }() +} + // getBalanceNotifyConfig reads global balance notification settings. func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, threshold float64) { keys := []string{SettingKeyBalanceLowNotifyEnabled, SettingKeyBalanceLowNotifyThreshold} @@ -191,7 +177,7 @@ func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []stri recipients := []string{user.Email} for _, extra := range user.BalanceNotifyExtraEmails { email := strings.TrimSpace(extra) - if email != "" && email != user.Email { + if email != "" && !strings.EqualFold(email, user.Email) { recipients = append(recipients, email) } } @@ -234,6 +220,7 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun } // buildBalanceLowEmailBody builds HTML email for balance low notification. +// Lines exceed 30 due to inline HTML template (not splittable). func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance, threshold float64, siteName string) string { return fmt.Sprintf(` @@ -271,6 +258,7 @@ func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance } // buildQuotaAlertEmailBody builds HTML email for account quota alert. +// Lines exceed 30 due to inline HTML template (not splittable). func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountName, dimLabel string, used, limit, threshold float64, siteName string) string { limitStr := fmt.Sprintf("$%.2f", limit) if limit <= 0 { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 72ab39ce..1203f0c6 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -7343,12 +7343,9 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) { deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, p.User.Balance, p.Cost.ActualCost) } - // Account quota notification + // Account quota notification (use same cost formula as postUsageBilling) if p.Cost.TotalCost > 0 && p.Account != nil && p.Account.IsAPIKeyOrBedrock() && deps.balanceNotifyService != nil { - accountCost := p.Cost.TotalCost - if p.AccountRateMultiplier > 0 { - accountCost *= p.AccountRateMultiplier - } + accountCost := p.Cost.TotalCost * p.AccountRateMultiplier deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost) } } diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index 9840a5e1..7c3afd23 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -1,6 +1,7 @@ + + From 30b926add431f41c57e4c972190920fe3884401e Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 12 Apr 2026 12:50:27 +0800 Subject: [PATCH 22/88] fix(notify): per-recipient timeout and return user on email removal - Use per-recipient context timeout in sendEmails to prevent later recipients from failing due to shared timeout exhaustion - Return updated user object from RemoveNotifyEmail handler for frontend state consistency (matching VerifyNotifyEmail pattern) --- backend/internal/handler/user_handler.go | 9 ++++++++- backend/internal/service/balance_notify_service.go | 4 ++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 42463a7a..4fb72ce7 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -205,5 +205,12 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) { return } - response.Success(c, gin.H{"message": "Email removed successfully"}) + // Return updated user + updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.UserFromService(updatedUser)) } diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 8223a231..8dd56b8f 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -186,13 +186,13 @@ func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []stri // sendEmails sends an email to all recipients with shared timeout and error logging. func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body string, logAttrs ...any) { - ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout) - defer cancel() for _, to := range recipients { + ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout) if err := s.emailService.SendEmail(ctx, to, subject, body); err != nil { attrs := append([]any{"to", to, "error", err}, logAttrs...) slog.Error("failed to send notification", attrs...) } + cancel() } } From d0674e0ff94028f32c5f0882b16b66013719d5b0 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 12 Apr 2026 13:11:46 +0800 Subject: [PATCH 23/88] feat(websearch): settings UI overhaul and quota improvements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove Priority field, auto load-balance by quota remaining - Replace QuotaRefreshInterval (daily/weekly/monthly) with SubscribedAt (subscription date, monthly lazy refresh via Redis TTL) - Add collapsible provider cards, API key show/copy, usage progress bar - Add test endpoint (POST /web-search-emulation/test) bypassing quota - Wire WebSearchManagerBuilder on startup (was never called before) - Fix nextMonthlyReset day-of-month overflow (Jan 31 → Feb 28) - Fix non-deterministic sort in selectByQuotaWeight - Map ProxyID in builder for provider-level proxy tracking - Fix frontend timezone drift in subscribed_at date picker - Fix provider deletion index shift for expandedProviders state --- .../internal/handler/admin/setting_handler.go | 86 +++-- backend/internal/pkg/websearch/manager.go | 173 ++++++--- .../internal/pkg/websearch/manager_test.go | 136 ++++--- backend/internal/server/http.go | 30 ++ backend/internal/server/routes/admin.go | 1 + backend/internal/service/websearch_config.go | 69 ++-- .../internal/service/websearch_config_test.go | 45 +-- frontend/src/api/admin/settings.ts | 22 +- frontend/src/i18n/locales/en.ts | 21 +- frontend/src/i18n/locales/zh.ts | 21 +- frontend/src/views/admin/SettingsView.vue | 351 ++++++++++++------ 11 files changed, 627 insertions(+), 328 deletions(-) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 459eade9..e5e024c6 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -175,9 +175,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { EnableFingerprintUnification: settings.EnableFingerprintUnification, EnableMetadataPassthrough: settings.EnableMetadataPassthrough, EnableCCHSigning: settings.EnableCCHSigning, - BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, - BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, - AccountQuotaNotifyEmails: settings.AccountQuotaNotifyEmails, + WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, PaymentEnabled: paymentCfg.Enabled, PaymentMinAmount: paymentCfg.MinAmount, PaymentMaxAmount: paymentCfg.MaxAmount, @@ -307,11 +305,6 @@ type UpdateSettingsRequest struct { EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` EnableCCHSigning *bool `json:"enable_cch_signing"` - // Balance low notification - BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` - BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` - AccountQuotaNotifyEmails *[]string `json:"account_quota_notify_emails"` - // Payment configuration (integrated into settings, full replace) PaymentEnabled *bool `json:"payment_enabled"` PaymentMinAmount *float64 `json:"payment_min_amount"` @@ -889,24 +882,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.EnableCCHSigning }(), - BalanceLowNotifyEnabled: func() bool { - if req.BalanceLowNotifyEnabled != nil { - return *req.BalanceLowNotifyEnabled - } - return previousSettings.BalanceLowNotifyEnabled - }(), - BalanceLowNotifyThreshold: func() float64 { - if req.BalanceLowNotifyThreshold != nil { - return *req.BalanceLowNotifyThreshold - } - return previousSettings.BalanceLowNotifyThreshold - }(), - AccountQuotaNotifyEmails: func() []string { - if req.AccountQuotaNotifyEmails != nil { - return *req.AccountQuotaNotifyEmails - } - return previousSettings.AccountQuotaNotifyEmails - }(), } if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { @@ -1053,9 +1028,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, EnableCCHSigning: updatedSettings.EnableCCHSigning, - BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, - BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold, - AccountQuotaNotifyEmails: updatedSettings.AccountQuotaNotifyEmails, PaymentEnabled: updatedPaymentCfg.Enabled, PaymentMinAmount: updatedPaymentCfg.MinAmount, PaymentMaxAmount: updatedPaymentCfg.MaxAmount, @@ -1876,3 +1848,59 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) { ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes, }) } + +// GetWebSearchEmulationConfig 获取 Web Search 模拟配置 +// GET /api/v1/admin/settings/web-search-emulation +func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) { + cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, service.SanitizeWebSearchConfig(c.Request.Context(), cfg)) +} + +// UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置 +// PUT /api/v1/admin/settings/web-search-emulation +func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) { + var cfg service.WebSearchEmulationConfig + if err := c.ShouldBindJSON(&cfg); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil { + response.ErrorFrom(c, err) + return + } + + // Re-read (with sanitized api keys) to return current state + updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, service.SanitizeWebSearchConfig(c.Request.Context(), updated)) +} + +// TestWebSearchEmulation 测试 Web Search 搜索 +// POST /api/v1/admin/settings/web-search-emulation/test +func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) { + var req struct { + Query string `json:"query"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if strings.TrimSpace(req.Query) == "" { + req.Query = "搜索今年世界大事件" + } + + result, err := service.TestWebSearch(c.Request.Context(), req.Query) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go index 7db3d9a2..ae0683ad 100644 --- a/backend/internal/pkg/websearch/manager.go +++ b/backend/internal/pkg/websearch/manager.go @@ -19,26 +19,18 @@ import ( "github.com/redis/go-redis/v9" ) -// Quota refresh interval constants. -const ( - QuotaRefreshDaily = "daily" - QuotaRefreshWeekly = "weekly" - QuotaRefreshMonthly = "monthly" -) - // ProviderConfig holds the configuration for a single search provider. type ProviderConfig struct { - Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily - APIKey string `json:"api_key"` // secret - Priority int `json:"priority"` // lower = higher priority - QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited - QuotaRefreshInterval string `json:"quota_refresh_interval"` // QuotaRefreshDaily / Weekly / Monthly - ProxyURL string `json:"-"` // resolved proxy URL (not persisted) - ProxyID int64 `json:"-"` // resolved proxy ID for unavailability tracking - ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds) + Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily + APIKey string `json:"api_key"` // secret + QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited + SubscribedAt *int64 `json:"subscribed_at,omitempty"` // subscription start (unix seconds); quota resets monthly from this date + ProxyURL string `json:"-"` // resolved proxy URL (not persisted) + ProxyID int64 `json:"-"` // resolved proxy ID for unavailability tracking + ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds) } -// Manager selects providers by priority and tracks quota via Redis. +// Manager selects providers by quota-weighted load balancing and tracks quota via Redis. type Manager struct { configs []ProviderConfig redis *redis.Client @@ -58,6 +50,7 @@ const ( proxyUnavailableKey = "websearch:proxy_unavailable:%d" proxyUnavailableTTL = 5 * time.Minute quotaTTLBuffer = 24 * time.Hour + defaultQuotaTTL = 31*24*time.Hour + quotaTTLBuffer // fallback when no subscription date maxCachedClients = 100 ) @@ -80,14 +73,12 @@ return val `) // NewManager creates a Manager with the given provider configs and Redis client. +// Provider order is preserved as-is; selectByQuotaWeight handles load balancing. func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager { - sorted := make([]ProviderConfig, len(configs)) - copy(sorted, configs) - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].Priority < sorted[j].Priority - }) + copied := make([]ProviderConfig, len(configs)) + copy(copied, configs) return &Manager{ - configs: sorted, + configs: copied, redis: redisClient, clientCache: make(map[string]*http.Client), } @@ -162,21 +153,28 @@ type weighted struct { // Providers with quota_limit=0 (no limit set) get weight 0 and are placed last. // Among providers with quota, higher remaining quota = higher priority. func (m *Manager) selectByQuotaWeight(ctx context.Context, candidates []ProviderConfig) []ProviderConfig { + items := m.computeWeights(ctx, candidates) + withQuota, withoutQuota := partitionByQuota(items) + sortByStableRandomWeight(withQuota) + return mergeWeightedResults(withQuota, withoutQuota, len(candidates)) +} + +func (m *Manager) computeWeights(ctx context.Context, candidates []ProviderConfig) []weighted { items := make([]weighted, 0, len(candidates)) for _, cfg := range candidates { w := int64(0) if cfg.QuotaLimit > 0 { - used, _ := m.GetUsage(ctx, cfg.Type, cfg.QuotaRefreshInterval) - remaining := cfg.QuotaLimit - used - if remaining > 0 { + used, _ := m.GetUsage(ctx, cfg.Type) + if remaining := cfg.QuotaLimit - used; remaining > 0 { w = remaining } } items = append(items, weighted{cfg: cfg, weight: w}) } + return items +} - // Separate providers with quota (weight > 0) from those without (weight == 0) - var withQuota, withoutQuota []weighted +func partitionByQuota(items []weighted) (withQuota, withoutQuota []weighted) { for _, item := range items { if item.weight > 0 { withQuota = append(withQuota, item) @@ -184,18 +182,26 @@ func (m *Manager) selectByQuotaWeight(ctx context.Context, candidates []Provider withoutQuota = append(withoutQuota, item) } } + return +} - // Within quota group: weighted random sort (higher remaining = more likely first) - if len(withQuota) > 1 { - sort.Slice(withQuota, func(i, j int) bool { - wi := float64(withQuota[i].weight) * (0.5 + rand.Float64()) - wj := float64(withQuota[j].weight) * (0.5 + rand.Float64()) - return wi > wj - }) +// sortByStableRandomWeight assigns a fixed random factor to each item before sorting, +// ensuring deterministic sort behavior (transitivity) within a single call. +func sortByStableRandomWeight(items []weighted) { + if len(items) <= 1 { + return } + factors := make([]float64, len(items)) + for i, item := range items { + factors[i] = float64(item.weight) * (0.5 + rand.Float64()) + } + sort.Slice(items, func(i, j int) bool { + return factors[i] > factors[j] + }) +} - // Build final order: quota providers first, then no-quota providers (original priority order) - result := make([]ProviderConfig, 0, len(candidates)) +func mergeWeightedResults(withQuota, withoutQuota []weighted, capacity int) []ProviderConfig { + result := make([]ProviderConfig, 0, capacity) for _, item := range withQuota { result = append(result, item.cfg) } @@ -294,8 +300,8 @@ func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool slog.Warn("websearch: Redis unavailable, quota check skipped", "provider", cfg.Type) return true, false } - key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval) - ttlSec := int(quotaTTL(cfg.QuotaRefreshInterval).Seconds()) + key := quotaRedisKey(cfg.Type) + ttlSec := int(quotaTTLFromSubscription(cfg.SubscribedAt).Seconds()) newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64() if err != nil { slog.Warn("websearch: quota Lua INCR failed, allowing request", @@ -318,7 +324,7 @@ func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) { if cfg.QuotaLimit <= 0 || m.redis == nil { return } - key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval) + key := quotaRedisKey(cfg.Type) if err := m.redis.Decr(ctx, key).Err(); err != nil { slog.Warn("websearch: quota rollback DECR failed", "provider", cfg.Type, "error", err) @@ -327,6 +333,25 @@ func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) { // --- Search execution --- +// TestSearch executes a search using the first available provider without reserving quota. +// Intended for admin test functionality only. +func (m *Manager) TestSearch(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) { + if strings.TrimSpace(req.Query) == "" { + return nil, "", fmt.Errorf("websearch: empty search query") + } + for _, cfg := range m.configs { + if !m.isProviderAvailable(cfg) { + continue + } + resp, err := m.executeSearch(ctx, cfg, req) + if err != nil { + continue + } + return resp, cfg.Type, nil + } + return nil, "", fmt.Errorf("websearch: no available provider") +} + func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) { proxyURL := cfg.ProxyURL if req.ProxyURL != "" { @@ -384,11 +409,11 @@ func newHTTPClient(proxyURL string) (*http.Client, error) { } // GetUsage returns the current usage count for the given provider. -func (m *Manager) GetUsage(ctx context.Context, providerType, refreshInterval string) (int64, error) { +func (m *Manager) GetUsage(ctx context.Context, providerType string) (int64, error) { if m.redis == nil { return 0, nil } - key := quotaRedisKey(providerType, refreshInterval) + key := quotaRedisKey(providerType) val, err := m.redis.Get(ctx, key).Int64() if err == redis.Nil { return 0, nil @@ -400,7 +425,7 @@ func (m *Manager) GetUsage(ctx context.Context, providerType, refreshInterval st func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 { result := make(map[string]int64, len(m.configs)) for _, cfg := range m.configs { - used, _ := m.GetUsage(ctx, cfg.Type, cfg.QuotaRefreshInterval) + used, _ := m.GetUsage(ctx, cfg.Type) result[cfg.Type] = used } return result @@ -423,30 +448,56 @@ func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provide // --- Redis key helpers --- -func quotaRedisKey(providerType, refreshInterval string) string { - return quotaKeyPrefix + providerType + ":" + periodKey(refreshInterval) +func quotaRedisKey(providerType string) string { + return quotaKeyPrefix + providerType } -func periodKey(refreshInterval string) string { +// quotaTTLFromSubscription calculates the TTL for the quota counter based on +// the provider's subscription start date. Quota resets monthly from that date. +// When the Redis key expires naturally, the next INCR creates a fresh counter (lazy refresh). +func quotaTTLFromSubscription(subscribedAt *int64) time.Duration { + if subscribedAt == nil || *subscribedAt == 0 { + return defaultQuotaTTL + } + next := nextMonthlyReset(time.Unix(*subscribedAt, 0).UTC()) + ttl := time.Until(next) + quotaTTLBuffer + if ttl <= quotaTTLBuffer { + // Already past the reset — next cycle + ttl = defaultQuotaTTL + } + return ttl +} + +// nextMonthlyReset returns the next monthly reset time based on the subscription start date. +// E.g., subscribed on Jan 15 → resets on Feb 15, Mar 15, etc. +// Handles day-of-month overflow: Jan 31 → Feb 28 (not Mar 3). +func nextMonthlyReset(subscribedAt time.Time) time.Time { now := time.Now().UTC() - switch refreshInterval { - case QuotaRefreshDaily: - return now.Format("2006-01-02") - case QuotaRefreshWeekly: - year, week := now.ISOWeek() - return fmt.Sprintf("%d-W%02d", year, week) - default: - return now.Format("2006-01") + if subscribedAt.IsZero() { + return now.AddDate(0, 1, 0) } + months := (now.Year()-subscribedAt.Year())*12 + int(now.Month()-subscribedAt.Month()) + if months < 0 { + months = 0 + } + candidate := addMonthsClamped(subscribedAt, months) + if candidate.After(now) { + return candidate + } + return addMonthsClamped(subscribedAt, months+1) } -func quotaTTL(refreshInterval string) time.Duration { - switch refreshInterval { - case QuotaRefreshDaily: - return 24*time.Hour + quotaTTLBuffer - case QuotaRefreshWeekly: - return 7*24*time.Hour + quotaTTLBuffer - default: - return 31*24*time.Hour + quotaTTLBuffer +// addMonthsClamped adds N months to a date, clamping the day to the last day of the target month. +// E.g., Jan 31 + 1 month = Feb 28 (not Mar 3). +func addMonthsClamped(t time.Time, months int) time.Time { + y, m, d := t.Date() + targetMonth := time.Month(int(m) + months) + targetYear := y + int(targetMonth-1)/12 + targetMonth = (targetMonth-1)%12 + 1 + // Last day of the target month + lastDay := time.Date(targetYear, targetMonth+1, 0, 0, 0, 0, 0, time.UTC).Day() + if d > lastDay { + d = lastDay } + return time.Date(targetYear, targetMonth, d, 0, 0, 0, 0, time.UTC) } diff --git a/backend/internal/pkg/websearch/manager_test.go b/backend/internal/pkg/websearch/manager_test.go index d3cd29d6..a4beef68 100644 --- a/backend/internal/pkg/websearch/manager_test.go +++ b/backend/internal/pkg/websearch/manager_test.go @@ -12,14 +12,14 @@ import ( "github.com/stretchr/testify/require" ) -func TestNewManager_SortsByPriority(t *testing.T) { +func TestNewManager_PreservesOrder(t *testing.T) { configs := []ProviderConfig{ - {Type: "brave", APIKey: "k3", Priority: 30}, - {Type: "tavily", APIKey: "k1", Priority: 10}, + {Type: "brave", APIKey: "k3"}, + {Type: "tavily", APIKey: "k1"}, } m := NewManager(configs, nil) - require.Equal(t, 10, m.configs[0].Priority) - require.Equal(t, 30, m.configs[1].Priority) + require.Equal(t, "brave", m.configs[0].Type) + require.Equal(t, "tavily", m.configs[1].Type) } func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) { @@ -46,8 +46,7 @@ func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) { require.ErrorContains(t, err, "no available provider") } -func TestManager_SearchWithBestProvider_PriorityOrder(t *testing.T) { - // Create two mock servers that return different results +func TestManager_SearchWithBestProvider_UsesFirstAvailable(t *testing.T) { srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { resp := braveResponse{} resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}} @@ -55,17 +54,15 @@ func TestManager_SearchWithBestProvider_PriorityOrder(t *testing.T) { })) defer srvBrave.Close() - // Override brave endpoint for test origURL := *braveSearchURL u, _ := http.NewRequest("GET", srvBrave.URL, nil) *braveSearchURL = *u.URL defer func() { *braveSearchURL = origURL }() m := NewManager([]ProviderConfig{ - {Type: "brave", APIKey: "k1", Priority: 1}, - {Type: "tavily", APIKey: "k2", Priority: 2}, + {Type: "brave", APIKey: "k1"}, + {Type: "tavily", APIKey: "k2"}, }, nil) - // Inject the test server's client m.clientCache[srvBrave.URL] = srvBrave.Client() m.clientCache[""] = srvBrave.Client() @@ -77,7 +74,6 @@ func TestManager_SearchWithBestProvider_PriorityOrder(t *testing.T) { } func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) { - // With nil Redis, quota check is skipped (always allowed) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { resp := braveResponse{} resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}} @@ -91,8 +87,8 @@ func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) { defer func() { *braveSearchURL = origURL }() m := NewManager([]ProviderConfig{ - {Type: "brave", APIKey: "k", Priority: 1, QuotaLimit: 100}, - }, nil) // nil Redis + {Type: "brave", APIKey: "k", QuotaLimit: 100}, + }, nil) m.clientCache[""] = srv.Client() resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"}) @@ -102,51 +98,98 @@ func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) { func TestManager_GetUsage_NilRedis(t *testing.T) { m := NewManager(nil, nil) - used, err := m.GetUsage(context.Background(), "brave", "monthly") + used, err := m.GetUsage(context.Background(), "brave") require.NoError(t, err) require.Equal(t, int64(0), used) } func TestManager_GetAllUsage_NilRedis(t *testing.T) { m := NewManager([]ProviderConfig{ - {Type: "brave", QuotaRefreshInterval: "monthly"}, + {Type: "brave"}, }, nil) usage := m.GetAllUsage(context.Background()) require.Equal(t, int64(0), usage["brave"]) } -// --- Key/TTL helpers --- +// --- Quota TTL from subscription --- -func TestQuotaTTL_Daily(t *testing.T) { - require.Equal(t, 24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshDaily)) +func TestQuotaTTLFromSubscription_NilSubscription(t *testing.T) { + ttl := quotaTTLFromSubscription(nil) + require.Equal(t, defaultQuotaTTL, ttl) } -func TestQuotaTTL_Weekly(t *testing.T) { - require.Equal(t, 7*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshWeekly)) +func TestQuotaTTLFromSubscription_ZeroSubscription(t *testing.T) { + zero := int64(0) + ttl := quotaTTLFromSubscription(&zero) + require.Equal(t, defaultQuotaTTL, ttl) } -func TestQuotaTTL_Monthly(t *testing.T) { - require.Equal(t, 31*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshMonthly)) +func TestQuotaTTLFromSubscription_ValidSubscription(t *testing.T) { + // Subscribed 10 days ago — next reset in ~20 days + sub := time.Now().Add(-10 * 24 * time.Hour).Unix() + ttl := quotaTTLFromSubscription(&sub) + require.Greater(t, ttl, 15*24*time.Hour) // at least 15 days + require.Less(t, ttl, 25*24*time.Hour+quotaTTLBuffer) } -func TestPeriodKey_Daily(t *testing.T) { - key := periodKey(QuotaRefreshDaily) - require.Regexp(t, `^\d{4}-\d{2}-\d{2}$`, key) +func TestNextMonthlyReset_SubscribedRecentPast(t *testing.T) { + // Subscribed on the 10th of this month (always valid day) + now := time.Now().UTC() + sub := time.Date(now.Year(), now.Month(), 10, 0, 0, 0, 0, time.UTC) + next := nextMonthlyReset(sub) + require.True(t, next.After(now) || next.Equal(now), "next reset should be in the future or now") + require.True(t, next.Before(now.AddDate(0, 1, 1))) } -func TestPeriodKey_Weekly(t *testing.T) { - key := periodKey(QuotaRefreshWeekly) - require.Regexp(t, `^\d{4}-W\d{2}$`, key) +func TestNextMonthlyReset_SubscribedLongAgo(t *testing.T) { + // Subscribed 6 months ago on the 1st + sub := time.Now().UTC().AddDate(0, -6, 0) + sub = time.Date(sub.Year(), sub.Month(), 1, 0, 0, 0, 0, time.UTC) + next := nextMonthlyReset(sub) + require.True(t, next.After(time.Now().UTC())) + // Should be within the next 31 days + require.True(t, next.Before(time.Now().UTC().AddDate(0, 1, 1))) } -func TestPeriodKey_Monthly(t *testing.T) { - key := periodKey(QuotaRefreshMonthly) - require.Regexp(t, `^\d{4}-\d{2}$`, key) +func TestNextMonthlyReset_FutureSubscription(t *testing.T) { + sub := time.Now().UTC().AddDate(0, 0, 5) + next := nextMonthlyReset(sub) + require.True(t, next.After(time.Now().UTC())) } +func TestAddMonthsClamped_Jan31ToFeb(t *testing.T) { + sub := time.Date(2026, 1, 31, 0, 0, 0, 0, time.UTC) + next := addMonthsClamped(sub, 1) + require.Equal(t, time.Month(2), next.Month()) + require.Equal(t, 28, next.Day()) // Feb 28 (2026 is not a leap year) +} + +func TestAddMonthsClamped_Jan31ToFebLeapYear(t *testing.T) { + sub := time.Date(2028, 1, 31, 0, 0, 0, 0, time.UTC) + next := addMonthsClamped(sub, 1) + require.Equal(t, time.Month(2), next.Month()) + require.Equal(t, 29, next.Day()) // Feb 29 (2028 is a leap year) +} + +func TestAddMonthsClamped_Mar31ToApr(t *testing.T) { + sub := time.Date(2026, 3, 31, 0, 0, 0, 0, time.UTC) + next := addMonthsClamped(sub, 1) + require.Equal(t, time.Month(4), next.Month()) + require.Equal(t, 30, next.Day()) // Apr has 30 days +} + +func TestAddMonthsClamped_NormalDay(t *testing.T) { + sub := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC) + next := addMonthsClamped(sub, 1) + require.Equal(t, time.Month(2), next.Month()) + require.Equal(t, 15, next.Day()) // no clamping needed +} + +// --- Redis key --- + func TestQuotaRedisKey_Format(t *testing.T) { - key := quotaRedisKey("brave", QuotaRefreshDaily) - require.Contains(t, key, "websearch:quota:brave:") + key := quotaRedisKey("brave") + require.Equal(t, "websearch:quota:brave", key) } // --- isProviderAvailable --- @@ -173,9 +216,7 @@ func TestIsProviderAvailable_Valid(t *testing.T) { func TestResolveProxyID_AccountProxyOverrides(t *testing.T) { cfg := ProviderConfig{ProxyID: 42} - // account proxy present → return 0 (account proxy has no config-level ID) require.Equal(t, int64(0), resolveProxyID(cfg, "http://account-proxy:8080")) - // no account proxy → return provider's proxy ID require.Equal(t, int64(42), resolveProxyID(cfg, "")) } @@ -186,28 +227,23 @@ func TestIsProxyError_Nil(t *testing.T) { } func TestIsProxyError_ConnectionRefused(t *testing.T) { - err := fmt.Errorf("dial tcp: connection refused") - require.True(t, isProxyError(err)) + require.True(t, isProxyError(fmt.Errorf("dial tcp: connection refused"))) } func TestIsProxyError_Timeout(t *testing.T) { - err := fmt.Errorf("i/o timeout while connecting to proxy") - require.True(t, isProxyError(err)) + require.True(t, isProxyError(fmt.Errorf("i/o timeout while connecting to proxy"))) } func TestIsProxyError_SOCKS(t *testing.T) { - err := fmt.Errorf("socks connect failed") - require.True(t, isProxyError(err)) + require.True(t, isProxyError(fmt.Errorf("socks connect failed"))) } func TestIsProxyError_TLSHandshake(t *testing.T) { - err := fmt.Errorf("tls handshake timeout") - require.True(t, isProxyError(err)) + require.True(t, isProxyError(fmt.Errorf("tls handshake timeout"))) } func TestIsProxyError_APIError_NotProxy(t *testing.T) { - err := fmt.Errorf("API rate limit exceeded") - require.False(t, isProxyError(err)) + require.False(t, isProxyError(fmt.Errorf("API rate limit exceeded"))) } // --- isProxyAvailable (nil Redis) --- @@ -225,14 +261,13 @@ func TestIsProxyAvailable_ZeroID(t *testing.T) { // --- selectByQuotaWeight --- func TestSelectByQuotaWeight_NoQuotaLast(t *testing.T) { - m := NewManager(nil, nil) // nil Redis → GetUsage returns 0 + m := NewManager(nil, nil) candidates := []ProviderConfig{ - {Type: "brave", APIKey: "k1", QuotaLimit: 0}, // no limit → weight 0 - {Type: "tavily", APIKey: "k2", QuotaLimit: 100}, // remaining 100 + {Type: "brave", APIKey: "k1", QuotaLimit: 0}, + {Type: "tavily", APIKey: "k2", QuotaLimit: 100}, } result := m.selectByQuotaWeight(context.Background(), candidates) require.Len(t, result, 2) - // tavily (with quota) should come first require.Equal(t, "tavily", result[0].Type) require.Equal(t, "brave", result[1].Type) } @@ -245,7 +280,6 @@ func TestSelectByQuotaWeight_AllNoQuota(t *testing.T) { } result := m.selectByQuotaWeight(context.Background(), candidates) require.Len(t, result, 2) - // both have weight 0, original order preserved } func TestSelectByQuotaWeight_Empty(t *testing.T) { diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index a8034e98..ba45c31b 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -2,12 +2,14 @@ package server import ( + "context" "log" "net/http" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/pkg/websearch" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -56,6 +58,34 @@ func ProvideRouter( } } + // Wire up websearch Manager builder so it initializes on startup and rebuilds on config save. + settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig) { + if cfg == nil || !cfg.Enabled || len(cfg.Providers) == 0 { + service.SetWebSearchManager(nil) + return + } + configs := make([]websearch.ProviderConfig, 0, len(cfg.Providers)) + for _, p := range cfg.Providers { + if p.APIKey == "" { + continue + } + pc := websearch.ProviderConfig{ + Type: p.Type, + APIKey: p.APIKey, + QuotaLimit: p.QuotaLimit, + ExpiresAt: p.ExpiresAt, + } + if p.SubscribedAt != nil { + pc.SubscribedAt = p.SubscribedAt + } + if p.ProxyID != nil { + pc.ProxyID = *p.ProxyID + } + configs = append(configs, pc) + } + service.SetWebSearchManager(websearch.NewManager(configs, redisClient)) + }) + return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 7c4e6cb7..0a7b7a8b 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -410,6 +410,7 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // Web Search 模拟配置 adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig) adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig) + adminSettings.POST("/web-search-emulation/test", h.Admin.Setting.TestWebSearchEmulation) } } diff --git a/backend/internal/service/websearch_config.go b/backend/internal/service/websearch_config.go index 15ec1f9d..bdfff3e4 100644 --- a/backend/internal/service/websearch_config.go +++ b/backend/internal/service/websearch_config.go @@ -22,15 +22,14 @@ type WebSearchEmulationConfig struct { // WebSearchProviderConfig describes a single search provider (Brave or Tavily). type WebSearchProviderConfig struct { - Type string `json:"type"` // websearch.ProviderTypeBrave | Tavily - APIKey string `json:"api_key,omitempty"` // secret — omitted in API responses - APIKeyConfigured bool `json:"api_key_configured"` // read-only mask - Priority int `json:"priority"` // lower = higher priority - QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited - QuotaRefreshInterval string `json:"quota_refresh_interval"` // websearch.QuotaRefresh* - QuotaUsed int64 `json:"quota_used,omitempty"` // read-only: current period usage - ProxyID *int64 `json:"proxy_id"` // optional proxy association - ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration timestamp + Type string `json:"type"` // websearch.ProviderTypeBrave | Tavily + APIKey string `json:"api_key,omitempty"` // secret — omitted in API responses + APIKeyConfigured bool `json:"api_key_configured"` // read-only mask + QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited + SubscribedAt *int64 `json:"subscribed_at,omitempty"` // subscription start (unix seconds); quota resets monthly + QuotaUsed int64 `json:"quota_used,omitempty"` // read-only: current usage from Redis + ProxyID *int64 `json:"proxy_id"` // optional proxy association + ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration timestamp } // --- Validation --- @@ -42,13 +41,6 @@ var validProviderTypes = map[string]bool{ websearch.ProviderTypeTavily: true, } -var validQuotaIntervals = map[string]bool{ - websearch.QuotaRefreshDaily: true, - websearch.QuotaRefreshWeekly: true, - websearch.QuotaRefreshMonthly: true, - "": true, // defaults to monthly -} - func validateWebSearchConfig(cfg *WebSearchEmulationConfig) error { if cfg == nil { return nil @@ -61,9 +53,6 @@ func validateWebSearchConfig(cfg *WebSearchEmulationConfig) error { if !validProviderTypes[p.Type] { return fmt.Errorf("provider[%d]: invalid type %q", i, p.Type) } - if !validQuotaIntervals[p.QuotaRefreshInterval] { - return fmt.Errorf("provider[%d]: invalid quota_refresh_interval %q", i, p.QuotaRefreshInterval) - } if p.QuotaLimit < 0 { return fmt.Errorf("provider[%d]: quota_limit must be >= 0", i) } @@ -237,17 +226,55 @@ func (s *SettingService) RebuildWebSearchManager(ctx context.Context) { slog.Info("websearch: manager rebuilt", "provider_count", len(providerConfigs)) } -// SanitizeWebSearchConfig returns a copy with api_key fields masked for API responses. -func SanitizeWebSearchConfig(cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig { +// WebSearchTestResult holds the result of a search test. +type WebSearchTestResult struct { + Provider string `json:"provider"` + Results []websearch.SearchResult `json:"results"` + Query string `json:"query"` +} + +// TestWebSearch executes a test search using the currently configured Manager. +// Uses Manager.TestSearch which bypasses quota tracking. +func TestWebSearch(ctx context.Context, query string) (*WebSearchTestResult, error) { + mgr := getWebSearchManager() + if mgr == nil { + return nil, fmt.Errorf("web search: manager not initialized, save config first") + } + resp, providerName, err := mgr.TestSearch(ctx, websearch.SearchRequest{ + Query: query, + MaxResults: webSearchDefaultMaxResults, + }) + if err != nil { + return nil, err + } + return &WebSearchTestResult{ + Provider: providerName, + Results: resp.Results, + Query: resp.Query, + }, nil +} + +// SanitizeWebSearchConfig returns a copy with api_key fields masked and quota usage populated. +func SanitizeWebSearchConfig(ctx context.Context, cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig { if cfg == nil { return nil } out := *cfg out.Providers = make([]WebSearchProviderConfig, len(cfg.Providers)) + + // Load usage from the global Manager (reads from Redis) + mgr := getWebSearchManager() + for i, p := range cfg.Providers { out.Providers[i] = p out.Providers[i].APIKeyConfigured = p.APIKey != "" out.Providers[i].APIKey = "" // never return the secret + + // Populate quota usage from Redis + if mgr != nil { + used, _ := mgr.GetUsage(ctx, p.Type) + out.Providers[i].QuotaUsed = used + } } return &out } diff --git a/backend/internal/service/websearch_config_test.go b/backend/internal/service/websearch_config_test.go index 1a19dd9d..4aea98b7 100644 --- a/backend/internal/service/websearch_config_test.go +++ b/backend/internal/service/websearch_config_test.go @@ -1,6 +1,7 @@ package service import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -16,8 +17,8 @@ func TestValidateWebSearchConfig_Valid(t *testing.T) { cfg := &WebSearchEmulationConfig{ Enabled: true, Providers: []WebSearchProviderConfig{ - {Type: "brave", Priority: 1, QuotaLimit: 1000, QuotaRefreshInterval: "monthly"}, - {Type: "tavily", Priority: 2, QuotaLimit: 500, QuotaRefreshInterval: "daily"}, + {Type: "brave", QuotaLimit: 1000}, + {Type: "tavily", QuotaLimit: 500}, }, } require.NoError(t, validateWebSearchConfig(cfg)) @@ -39,13 +40,6 @@ func TestValidateWebSearchConfig_InvalidType(t *testing.T) { require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid type") } -func TestValidateWebSearchConfig_InvalidQuotaInterval(t *testing.T) { - cfg := &WebSearchEmulationConfig{ - Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: "hourly"}}, - } - require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid quota_refresh_interval") -} - func TestValidateWebSearchConfig_NegativeQuotaLimit(t *testing.T) { cfg := &WebSearchEmulationConfig{ Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: -1}}, @@ -56,20 +50,13 @@ func TestValidateWebSearchConfig_NegativeQuotaLimit(t *testing.T) { func TestValidateWebSearchConfig_DuplicateType(t *testing.T) { cfg := &WebSearchEmulationConfig{ Providers: []WebSearchProviderConfig{ - {Type: "brave", Priority: 1}, - {Type: "brave", Priority: 2}, + {Type: "brave"}, + {Type: "brave"}, }, } require.ErrorContains(t, validateWebSearchConfig(cfg), "duplicate type") } -func TestValidateWebSearchConfig_EmptyQuotaInterval(t *testing.T) { - cfg := &WebSearchEmulationConfig{ - Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: ""}}, - } - require.NoError(t, validateWebSearchConfig(cfg)) -} - func TestValidateWebSearchConfig_ZeroQuotaLimit(t *testing.T) { cfg := &WebSearchEmulationConfig{ Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: 0}}, @@ -99,6 +86,15 @@ func TestParseWebSearchConfigJSON_InvalidJSON(t *testing.T) { require.Empty(t, cfg.Providers) } +func TestParseWebSearchConfigJSON_BackwardCompatibility(t *testing.T) { + // Old config with priority and quota_refresh_interval should parse without error + raw := `{"enabled":true,"providers":[{"type":"brave","priority":1,"quota_refresh_interval":"monthly","quota_limit":1000}]}` + cfg := parseWebSearchConfigJSON(raw) + require.True(t, cfg.Enabled) + require.Len(t, cfg.Providers, 1) + require.Equal(t, int64(1000), cfg.Providers[0].QuotaLimit) +} + // --- SanitizeWebSearchConfig --- func TestSanitizeWebSearchConfig_MaskAPIKey(t *testing.T) { @@ -108,7 +104,7 @@ func TestSanitizeWebSearchConfig_MaskAPIKey(t *testing.T) { {Type: "brave", APIKey: "sk-secret-xxx"}, }, } - out := SanitizeWebSearchConfig(cfg) + out := SanitizeWebSearchConfig(context.Background(), cfg) require.Equal(t, "", out.Providers[0].APIKey) require.True(t, out.Providers[0].APIKeyConfigured) } @@ -117,25 +113,24 @@ func TestSanitizeWebSearchConfig_NoAPIKey(t *testing.T) { cfg := &WebSearchEmulationConfig{ Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: ""}}, } - out := SanitizeWebSearchConfig(cfg) + out := SanitizeWebSearchConfig(context.Background(), cfg) require.Equal(t, "", out.Providers[0].APIKey) require.False(t, out.Providers[0].APIKeyConfigured) } func TestSanitizeWebSearchConfig_Nil(t *testing.T) { - require.Nil(t, SanitizeWebSearchConfig(nil)) + require.Nil(t, SanitizeWebSearchConfig(context.Background(), nil)) } func TestSanitizeWebSearchConfig_PreservesOtherFields(t *testing.T) { cfg := &WebSearchEmulationConfig{ Enabled: true, Providers: []WebSearchProviderConfig{ - {Type: "brave", APIKey: "secret", Priority: 10, QuotaLimit: 1000}, + {Type: "brave", APIKey: "secret", QuotaLimit: 1000}, }, } - out := SanitizeWebSearchConfig(cfg) + out := SanitizeWebSearchConfig(context.Background(), cfg) require.True(t, out.Enabled) - require.Equal(t, 10, out.Providers[0].Priority) require.Equal(t, int64(1000), out.Providers[0].QuotaLimit) } @@ -143,6 +138,6 @@ func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) { cfg := &WebSearchEmulationConfig{ Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "secret"}}, } - _ = SanitizeWebSearchConfig(cfg) + _ = SanitizeWebSearchConfig(context.Background(), cfg) require.Equal(t, "secret", cfg.Providers[0].APIKey) } diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index c6323b00..31284289 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -497,9 +497,8 @@ export interface WebSearchProviderConfig { type: 'brave' | 'tavily' api_key: string api_key_configured: boolean - priority: number quota_limit: number - quota_refresh_interval: 'daily' | 'weekly' | 'monthly' + subscribed_at: number | null quota_used?: number proxy_id: number | null expires_at: number | null @@ -510,6 +509,12 @@ export interface WebSearchEmulationConfig { providers: WebSearchProviderConfig[] } +export interface WebSearchTestResult { + provider: string + results: { url: string; title: string; snippet: string; page_age?: string }[] + query: string +} + export async function getWebSearchEmulationConfig(): Promise { const { data } = await apiClient.get( '/admin/settings/web-search-emulation' @@ -527,6 +532,16 @@ export async function updateWebSearchEmulationConfig( return data } +export async function testWebSearchEmulation( + query: string +): Promise { + const { data } = await apiClient.post( + '/admin/settings/web-search-emulation/test', + { query } + ) + return data +} + export const settingsAPI = { getSettings, updateSettings, @@ -544,7 +559,8 @@ export const settingsAPI = { getBetaPolicySettings, updateBetaPolicySettings, getWebSearchEmulationConfig, - updateWebSearchEmulationConfig + updateWebSearchEmulationConfig, + testWebSearchEmulation } export default settingsAPI diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 7119fa36..8e10bf2a 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -4417,19 +4417,24 @@ export default { apiKey: 'API Key', apiKeyPlaceholder: 'Enter API Key', apiKeyConfigured: 'Configured', - priority: 'Priority', - priorityHint: 'Lower number = higher priority', + showApiKey: 'Show', + hideApiKey: 'Hide', + copyApiKey: 'Copy', + copied: 'Copied', quotaLimit: 'Quota Limit', quotaLimitHint: '0 = unlimited', - quotaRefreshInterval: 'Refresh Interval', - quotaUsed: 'Used', + subscribedAt: 'Subscribed At', + subscribedAtHint: 'Quota resets monthly from this date', + quotaUsage: 'Usage', proxy: 'Proxy', - expiresAt: 'Expires At', removeProvider: 'Remove', - daily: 'Daily', - weekly: 'Weekly', - monthly: 'Monthly', noProviders: 'No search providers configured', + test: 'Test', + testDefaultQuery: 'Major world events this year', + testing: 'Searching...', + testResultTitle: 'Search Results', + testResultProvider: 'Provider', + testNoResults: 'No results found', }, site: { title: 'Site Settings', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 6efaf657..1b82f419 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -4579,19 +4579,24 @@ export default { apiKey: 'API Key', apiKeyPlaceholder: '输入 API Key', apiKeyConfigured: '已配置', - priority: '优先级', - priorityHint: '数值越小优先级越高', + showApiKey: '显示', + hideApiKey: '隐藏', + copyApiKey: '复制', + copied: '已复制', quotaLimit: '配额上限', quotaLimitHint: '0 表示无限制', - quotaRefreshInterval: '刷新周期', - quotaUsed: '已使用', + subscribedAt: '订阅时间', + subscribedAtHint: '配额从此日期起每月自动重置', + quotaUsage: '用量', proxy: '代理', - expiresAt: '过期时间', removeProvider: '删除', - daily: '每日', - weekly: '每周', - monthly: '每月', noProviders: '未配置搜索服务商', + test: '测试', + testDefaultQuery: '搜索今年世界大事件', + testing: '搜索中...', + testResultTitle: '搜索结果', + testResultProvider: '服务商', + testNoResults: '无搜索结果', }, site: { title: '站点设置', diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index f57bfcf3..8ed77203 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -1751,61 +1751,152 @@
-
- + + + {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit }} + + + {{ t('admin.settings.webSearchEmulation.apiKeyConfigured') }} + +
+
-
+ +
+
- +
+ + + +
-
- - -

{{ t('admin.settings.webSearchEmulation.priorityHint') }}

-
-
- - -

{{ t('admin.settings.webSearchEmulation.quotaLimitHint') }}

-

- {{ t('admin.settings.webSearchEmulation.quotaUsed') }}: {{ provider.quota_used }} / {{ provider.quota_limit || '∞' }} -

-
-
- - +

{{ t('admin.settings.webSearchEmulation.quotaLimitHint') }}

+
+
+ + +

{{ t('admin.settings.webSearchEmulation.subscribedAtHint') }}

+
+
+ + +
+ {{ t('admin.settings.webSearchEmulation.quotaUsage') }}: +
+
+
+ {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit }} +
+ + +
+ + +
+ + +
+
+ + +
+ +
+

+ {{ t('admin.settings.webSearchEmulation.testResultProvider') }}: {{ wsTestResult.provider }} +

+
+ {{ t('admin.settings.webSearchEmulation.testNoResults') }} +
+
+ {{ r.title }} +

{{ r.snippet && r.snippet.length > 120 ? r.snippet.slice(0, 120) + '...' : r.snippet }}

+
+
+
@@ -2303,6 +2394,13 @@ ]" >{{ pt.label }} +

+ {{ t('admin.settings.payment.enabledPaymentTypesHint') }} + + {{ t('admin.settings.payment.findProvider') }} + + +

@@ -2562,60 +2660,6 @@
- -
-
-

- {{ t('admin.settings.balanceNotify.title') }} -

-

- {{ t('admin.settings.balanceNotify.description') }} -

-
-
-
- - -
-
- -
- $ - -
-

{{ t('admin.settings.balanceNotify.thresholdHint') }}

-
-
-
- - -
-
-

- {{ t('admin.settings.quotaNotify.title') }} -

-

- {{ t('admin.settings.quotaNotify.description') }} -

-
-
-
- -
-
- - -
- -
-

{{ t('admin.settings.quotaNotify.emailsHint') }}

-
-
-
@@ -2674,8 +2718,9 @@ import type { DefaultSubscriptionSetting, WebSearchEmulationConfig, WebSearchProviderConfig, + WebSearchTestResult, } from '@/api/admin/settings' -import type { AdminGroup } from '@/types' +import type { AdminGroup, Proxy } from '@/types' import type { ProviderInstance } from '@/types/payment' import AppLayout from '@/components/layout/AppLayout.vue' import Icon from '@/components/icons/Icon.vue' @@ -2894,13 +2939,12 @@ const form = reactive({ // Gateway forwarding behavior enable_fingerprint_unification: true, enable_metadata_passthrough: false, - enable_cch_signing: false, - // Balance & quota notification - balance_low_notify_enabled: false, - balance_low_notify_threshold: 0, - account_quota_notify_emails: [] as string[] + enable_cch_signing: false }) +// Proxies for web search emulation ProxySelector +const webSearchProxies = ref([]) + // Web Search Emulation config (loaded/saved separately) const DEFAULT_WEB_SEARCH_QUOTA_LIMIT = 1000 @@ -2909,26 +2953,101 @@ const webSearchConfig = reactive({ providers: [], }) +const expandedProviders = reactive>({}) +const apiKeyVisible = reactive>({}) +const wsTestQuery = ref('') +const wsTestLoading = ref(false) +const wsTestResult = ref(null) + +function toggleProviderExpand(idx: number) { + expandedProviders[idx] = !expandedProviders[idx] +} + +function removeWebSearchProvider(idx: number) { + webSearchConfig.providers.splice(idx, 1) + // Re-index expandedProviders and apiKeyVisible after removal + const newExpanded: Record = {} + const newVisible: Record = {} + for (let i = 0; i < webSearchConfig.providers.length; i++) { + const oldIdx = i >= idx ? i + 1 : i + newExpanded[i] = expandedProviders[oldIdx] ?? false + newVisible[i] = apiKeyVisible[oldIdx] ?? false + } + Object.keys(expandedProviders).forEach((k) => delete expandedProviders[Number(k)]) + Object.keys(apiKeyVisible).forEach((k) => delete apiKeyVisible[Number(k)]) + Object.assign(expandedProviders, newExpanded) + Object.assign(apiKeyVisible, newVisible) +} + function addWebSearchProvider() { + const idx = webSearchConfig.providers.length webSearchConfig.providers.push({ type: 'brave', api_key: '', api_key_configured: false, - priority: webSearchConfig.providers.length + 1, quota_limit: DEFAULT_WEB_SEARCH_QUOTA_LIMIT, - quota_refresh_interval: 'monthly', + subscribed_at: null, proxy_id: null, expires_at: null, } as WebSearchProviderConfig) + expandedProviders[idx] = true +} + +function formatSubscribedAt(ts: number | null): string { + if (!ts) return '' + // Use UTC to avoid timezone drift on repeated edits + const d = new Date(ts * 1000) + const y = d.getUTCFullYear() + const m = String(d.getUTCMonth() + 1).padStart(2, '0') + const day = String(d.getUTCDate()).padStart(2, '0') + return `${y}-${m}-${day}` +} + +function parseSubscribedAt(dateStr: string): number | null { + if (!dateStr) return null + // Parse as UTC to match formatSubscribedAt + return Math.floor(new Date(dateStr + 'T00:00:00Z').getTime() / 1000) +} + +function quotaPercentage(provider: WebSearchProviderConfig): number { + if (!provider.quota_limit || provider.quota_limit <= 0) return 0 + return ((provider.quota_used ?? 0) / provider.quota_limit) * 100 +} + +async function copyApiKey(idx: number) { + const key = webSearchConfig.providers[idx]?.api_key + if (!key) { + appStore.showError(t('admin.settings.webSearchEmulation.apiKeyPlaceholder')) + return + } + await navigator.clipboard.writeText(key) + appStore.showSuccess(t('admin.settings.webSearchEmulation.copied')) +} + +async function testWebSearchProvider() { + wsTestLoading.value = true + wsTestResult.value = null + try { + const query = wsTestQuery.value.trim() || t('admin.settings.webSearchEmulation.testDefaultQuery') + wsTestResult.value = await adminAPI.settings.testWebSearchEmulation(query) + } catch (err: unknown) { + appStore.showError(extractApiErrorMessage(err, t('common.error'))) + } finally { + wsTestLoading.value = false + } } async function loadWebSearchConfig() { try { - const resp = await adminAPI.settings.getWebSearchEmulationConfig() + const [resp, proxiesResp] = await Promise.all([ + adminAPI.settings.getWebSearchEmulationConfig(), + adminAPI.proxies.list().catch(() => ({ items: [] as Proxy[] })), + ]) if (resp) { webSearchConfig.enabled = resp.enabled || false webSearchConfig.providers = resp.providers || [] } + webSearchProxies.value = proxiesResp.items || [] } catch (err: unknown) { // 404 is expected when config hasn't been created yet; show error for other failures const status = (err as { status?: number })?.status @@ -3030,14 +3149,6 @@ function handleRegistrationEmailSuffixWhitelistPaste(event: ClipboardEvent) { } } -// Quota notify email helpers -const addQuotaNotifyEmail = () => { - if (!form.account_quota_notify_emails) { - form.account_quota_notify_emails = [] - } - form.account_quota_notify_emails.push('') -} - // LinuxDo OAuth redirect URL suggestion const linuxdoRedirectUrlSuggestion = computed(() => { if (typeof window === 'undefined') return '' @@ -3377,10 +3488,6 @@ async function saveSettings() { payment_cancel_rate_limit_window: Number(form.payment_cancel_rate_limit_window) || 1, payment_cancel_rate_limit_unit: form.payment_cancel_rate_limit_unit, payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode, - // Balance & quota notification - balance_low_notify_enabled: form.balance_low_notify_enabled, - balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0, - account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e: string) => e.trim() !== ''), } const updated = await adminAPI.settings.updateSettings(payload) From f694afbbf431122407fe7f4bfb3d7f4ee0e5654f Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 12 Apr 2026 13:53:02 +0800 Subject: [PATCH 24/88] feat(notify): add percentage threshold type for balance low notification MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add threshold_type field (fixed/percentage) to system and user settings - Add total_recharged field to users table, auto-incremented on balance credit - Percentage mode: effective threshold = total_recharged × percentage / 100 - User-level threshold_type inherits from system default when not set - Update admin settings UI with radio selector (fixed amount / percentage) - Migration: 102_add_balance_notify_threshold_type.sql --- backend/ent/migrate/schema.go | 2 + backend/ent/mutation.go | 143 ++++++++++++++++- backend/ent/runtime/runtime.go | 10 +- backend/ent/schema/user.go | 5 + backend/ent/user.go | 26 ++- backend/ent/user/user.go | 20 +++ backend/ent/user/where.go | 115 ++++++++++++++ backend/ent/user_create.go | 150 ++++++++++++++++++ backend/ent/user_update.go | 88 ++++++++++ .../internal/handler/admin/setting_handler.go | 38 +++++ backend/internal/handler/dto/mappers.go | 28 ++-- backend/internal/handler/dto/settings.go | 7 +- backend/internal/handler/dto/types.go | 8 +- backend/internal/handler/user_handler.go | 14 +- backend/internal/repository/api_key_repo.go | 34 ++-- backend/internal/repository/user_repo.go | 8 +- .../service/balance_notify_service.go | 51 ++++-- backend/internal/service/domain_constants.go | 9 +- backend/internal/service/setting_service.go | 9 ++ backend/internal/service/settings_view.go | 5 +- backend/internal/service/user.go | 8 +- backend/internal/service/user_service.go | 14 +- .../102_add_balance_notify_threshold_type.sql | 4 + frontend/src/api/admin/settings.ts | 2 + frontend/src/i18n/locales/en.ts | 4 + frontend/src/i18n/locales/zh.ts | 6 +- frontend/src/views/admin/SettingsView.vue | 104 +++++++++++- 27 files changed, 838 insertions(+), 74 deletions(-) create mode 100644 backend/migrations/102_add_balance_notify_threshold_type.sql diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 4f31883b..1fff61ba 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -1079,8 +1079,10 @@ var ( {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, {Name: "balance_notify_enabled", Type: field.TypeBool, Default: true}, + {Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"}, {Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index cdaf363a..3bca248d 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -28211,9 +28211,12 @@ type UserMutation struct { totp_enabled *bool totp_enabled_at *time.Time balance_notify_enabled *bool + balance_notify_threshold_type *string balance_notify_threshold *float64 addbalance_notify_threshold *float64 balance_notify_extra_emails *string + total_recharged *float64 + addtotal_recharged *float64 clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -28967,6 +28970,42 @@ func (m *UserMutation) ResetBalanceNotifyEnabled() { m.balance_notify_enabled = nil } +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (m *UserMutation) SetBalanceNotifyThresholdType(s string) { + m.balance_notify_threshold_type = &s +} + +// BalanceNotifyThresholdType returns the value of the "balance_notify_threshold_type" field in the mutation. +func (m *UserMutation) BalanceNotifyThresholdType() (r string, exists bool) { + v := m.balance_notify_threshold_type + if v == nil { + return + } + return *v, true +} + +// OldBalanceNotifyThresholdType returns the old "balance_notify_threshold_type" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldBalanceNotifyThresholdType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBalanceNotifyThresholdType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBalanceNotifyThresholdType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBalanceNotifyThresholdType: %w", err) + } + return oldValue.BalanceNotifyThresholdType, nil +} + +// ResetBalanceNotifyThresholdType resets all changes to the "balance_notify_threshold_type" field. +func (m *UserMutation) ResetBalanceNotifyThresholdType() { + m.balance_notify_threshold_type = nil +} + // SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. func (m *UserMutation) SetBalanceNotifyThreshold(f float64) { m.balance_notify_threshold = &f @@ -29073,6 +29112,62 @@ func (m *UserMutation) ResetBalanceNotifyExtraEmails() { m.balance_notify_extra_emails = nil } +// SetTotalRecharged sets the "total_recharged" field. +func (m *UserMutation) SetTotalRecharged(f float64) { + m.total_recharged = &f + m.addtotal_recharged = nil +} + +// TotalRecharged returns the value of the "total_recharged" field in the mutation. +func (m *UserMutation) TotalRecharged() (r float64, exists bool) { + v := m.total_recharged + if v == nil { + return + } + return *v, true +} + +// OldTotalRecharged returns the old "total_recharged" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldTotalRecharged(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotalRecharged is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotalRecharged requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotalRecharged: %w", err) + } + return oldValue.TotalRecharged, nil +} + +// AddTotalRecharged adds f to the "total_recharged" field. +func (m *UserMutation) AddTotalRecharged(f float64) { + if m.addtotal_recharged != nil { + *m.addtotal_recharged += f + } else { + m.addtotal_recharged = &f + } +} + +// AddedTotalRecharged returns the value that was added to the "total_recharged" field in this mutation. +func (m *UserMutation) AddedTotalRecharged() (r float64, exists bool) { + v := m.addtotal_recharged + if v == nil { + return + } + return *v, true +} + +// ResetTotalRecharged resets all changes to the "total_recharged" field. +func (m *UserMutation) ResetTotalRecharged() { + m.total_recharged = nil + m.addtotal_recharged = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *UserMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -29647,7 +29742,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 17) + fields := make([]string, 0, 19) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -29693,12 +29788,18 @@ func (m *UserMutation) Fields() []string { if m.balance_notify_enabled != nil { fields = append(fields, user.FieldBalanceNotifyEnabled) } + if m.balance_notify_threshold_type != nil { + fields = append(fields, user.FieldBalanceNotifyThresholdType) + } if m.balance_notify_threshold != nil { fields = append(fields, user.FieldBalanceNotifyThreshold) } if m.balance_notify_extra_emails != nil { fields = append(fields, user.FieldBalanceNotifyExtraEmails) } + if m.total_recharged != nil { + fields = append(fields, user.FieldTotalRecharged) + } return fields } @@ -29737,10 +29838,14 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.TotpEnabledAt() case user.FieldBalanceNotifyEnabled: return m.BalanceNotifyEnabled() + case user.FieldBalanceNotifyThresholdType: + return m.BalanceNotifyThresholdType() case user.FieldBalanceNotifyThreshold: return m.BalanceNotifyThreshold() case user.FieldBalanceNotifyExtraEmails: return m.BalanceNotifyExtraEmails() + case user.FieldTotalRecharged: + return m.TotalRecharged() } return nil, false } @@ -29780,10 +29885,14 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldTotpEnabledAt(ctx) case user.FieldBalanceNotifyEnabled: return m.OldBalanceNotifyEnabled(ctx) + case user.FieldBalanceNotifyThresholdType: + return m.OldBalanceNotifyThresholdType(ctx) case user.FieldBalanceNotifyThreshold: return m.OldBalanceNotifyThreshold(ctx) case user.FieldBalanceNotifyExtraEmails: return m.OldBalanceNotifyExtraEmails(ctx) + case user.FieldTotalRecharged: + return m.OldTotalRecharged(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -29898,6 +30007,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetBalanceNotifyEnabled(v) return nil + case user.FieldBalanceNotifyThresholdType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBalanceNotifyThresholdType(v) + return nil case user.FieldBalanceNotifyThreshold: v, ok := value.(float64) if !ok { @@ -29912,6 +30028,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetBalanceNotifyExtraEmails(v) return nil + case user.FieldTotalRecharged: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotalRecharged(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -29929,6 +30052,9 @@ func (m *UserMutation) AddedFields() []string { if m.addbalance_notify_threshold != nil { fields = append(fields, user.FieldBalanceNotifyThreshold) } + if m.addtotal_recharged != nil { + fields = append(fields, user.FieldTotalRecharged) + } return fields } @@ -29943,6 +30069,8 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) { return m.AddedConcurrency() case user.FieldBalanceNotifyThreshold: return m.AddedBalanceNotifyThreshold() + case user.FieldTotalRecharged: + return m.AddedTotalRecharged() } return nil, false } @@ -29973,6 +30101,13 @@ func (m *UserMutation) AddField(name string, value ent.Value) error { } m.AddBalanceNotifyThreshold(v) return nil + case user.FieldTotalRecharged: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddTotalRecharged(v) + return nil } return fmt.Errorf("unknown User numeric field %s", name) } @@ -30072,12 +30207,18 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldBalanceNotifyEnabled: m.ResetBalanceNotifyEnabled() return nil + case user.FieldBalanceNotifyThresholdType: + m.ResetBalanceNotifyThresholdType() + return nil case user.FieldBalanceNotifyThreshold: m.ResetBalanceNotifyThreshold() return nil case user.FieldBalanceNotifyExtraEmails: m.ResetBalanceNotifyExtraEmails() return nil + case user.FieldTotalRecharged: + m.ResetTotalRecharged() + return nil } return fmt.Errorf("unknown User field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index a288f5d9..951b5f99 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -1297,10 +1297,18 @@ func init() { userDescBalanceNotifyEnabled := userFields[11].Descriptor() // user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field. user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool) + // userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field. + userDescBalanceNotifyThresholdType := userFields[12].Descriptor() + // user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field. + user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string) // userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field. - userDescBalanceNotifyExtraEmails := userFields[13].Descriptor() + userDescBalanceNotifyExtraEmails := userFields[14].Descriptor() // user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field. user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string) + // userDescTotalRecharged is the schema descriptor for total_recharged field. + userDescTotalRecharged := userFields[15].Descriptor() + // user.DefaultTotalRecharged holds the default value on creation for the total_recharged field. + user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() _ = userallowedgroupFields // userallowedgroupDescCreatedAt is the schema descriptor for created_at field. diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index bdaa4509..ef52e985 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -76,6 +76,8 @@ func (User) Fields() []ent.Field { // 余额不足通知 field.Bool("balance_notify_enabled"). Default(true), + field.String("balance_notify_threshold_type"). + Default("fixed"), // "fixed" | "percentage" field.Float("balance_notify_threshold"). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). Optional(). @@ -83,6 +85,9 @@ func (User) Fields() []ent.Field { field.String("balance_notify_extra_emails"). SchemaType(map[string]string{dialect.Postgres: "text"}). Default("[]"), + field.Float("total_recharged"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0), } } diff --git a/backend/ent/user.go b/backend/ent/user.go index fc4ddb8f..9fa91f74 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -47,10 +47,14 @@ type User struct { TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` // BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field. BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"` + // BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field. + BalanceNotifyThresholdType string `json:"balance_notify_threshold_type,omitempty"` // BalanceNotifyThreshold holds the value of the "balance_notify_threshold" field. BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"` // BalanceNotifyExtraEmails holds the value of the "balance_notify_extra_emails" field. BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"` + // TotalRecharged holds the value of the "total_recharged" field. + TotalRecharged float64 `json:"total_recharged,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. Edges UserEdges `json:"edges"` @@ -192,11 +196,11 @@ func (*User) scanValues(columns []string) ([]any, error) { switch columns[i] { case user.FieldTotpEnabled, user.FieldBalanceNotifyEnabled: values[i] = new(sql.NullBool) - case user.FieldBalance, user.FieldBalanceNotifyThreshold: + case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged: values[i] = new(sql.NullFloat64) case user.FieldID, user.FieldConcurrency: values[i] = new(sql.NullInt64) - case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyExtraEmails: + case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: values[i] = new(sql.NullString) case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt: values[i] = new(sql.NullTime) @@ -314,6 +318,12 @@ func (_m *User) assignValues(columns []string, values []any) error { } else if value.Valid { _m.BalanceNotifyEnabled = value.Bool } + case user.FieldBalanceNotifyThresholdType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field balance_notify_threshold_type", values[i]) + } else if value.Valid { + _m.BalanceNotifyThresholdType = value.String + } case user.FieldBalanceNotifyThreshold: if value, ok := values[i].(*sql.NullFloat64); !ok { return fmt.Errorf("unexpected type %T for field balance_notify_threshold", values[i]) @@ -327,6 +337,12 @@ func (_m *User) assignValues(columns []string, values []any) error { } else if value.Valid { _m.BalanceNotifyExtraEmails = value.String } + case user.FieldTotalRecharged: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field total_recharged", values[i]) + } else if value.Valid { + _m.TotalRecharged = value.Float64 + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -469,6 +485,9 @@ func (_m *User) String() string { builder.WriteString("balance_notify_enabled=") builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled)) builder.WriteString(", ") + builder.WriteString("balance_notify_threshold_type=") + builder.WriteString(_m.BalanceNotifyThresholdType) + builder.WriteString(", ") if v := _m.BalanceNotifyThreshold; v != nil { builder.WriteString("balance_notify_threshold=") builder.WriteString(fmt.Sprintf("%v", *v)) @@ -476,6 +495,9 @@ func (_m *User) String() string { builder.WriteString(", ") builder.WriteString("balance_notify_extra_emails=") builder.WriteString(_m.BalanceNotifyExtraEmails) + builder.WriteString(", ") + builder.WriteString("total_recharged=") + builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index aff37013..d88a3a38 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -45,10 +45,14 @@ const ( FieldTotpEnabledAt = "totp_enabled_at" // FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database. FieldBalanceNotifyEnabled = "balance_notify_enabled" + // FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database. + FieldBalanceNotifyThresholdType = "balance_notify_threshold_type" // FieldBalanceNotifyThreshold holds the string denoting the balance_notify_threshold field in the database. FieldBalanceNotifyThreshold = "balance_notify_threshold" // FieldBalanceNotifyExtraEmails holds the string denoting the balance_notify_extra_emails field in the database. FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails" + // FieldTotalRecharged holds the string denoting the total_recharged field in the database. + FieldTotalRecharged = "total_recharged" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -168,8 +172,10 @@ var Columns = []string{ FieldTotpEnabled, FieldTotpEnabledAt, FieldBalanceNotifyEnabled, + FieldBalanceNotifyThresholdType, FieldBalanceNotifyThreshold, FieldBalanceNotifyExtraEmails, + FieldTotalRecharged, } var ( @@ -228,8 +234,12 @@ var ( DefaultTotpEnabled bool // DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field. DefaultBalanceNotifyEnabled bool + // DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field. + DefaultBalanceNotifyThresholdType string // DefaultBalanceNotifyExtraEmails holds the default value on creation for the "balance_notify_extra_emails" field. DefaultBalanceNotifyExtraEmails string + // DefaultTotalRecharged holds the default value on creation for the "total_recharged" field. + DefaultTotalRecharged float64 ) // OrderOption defines the ordering options for the User queries. @@ -315,6 +325,11 @@ func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc() } +// ByBalanceNotifyThresholdType orders the results by the balance_notify_threshold_type field. +func ByBalanceNotifyThresholdType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBalanceNotifyThresholdType, opts...).ToFunc() +} + // ByBalanceNotifyThreshold orders the results by the balance_notify_threshold field. func ByBalanceNotifyThreshold(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldBalanceNotifyThreshold, opts...).ToFunc() @@ -325,6 +340,11 @@ func ByBalanceNotifyExtraEmails(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldBalanceNotifyExtraEmails, opts...).ToFunc() } +// ByTotalRecharged orders the results by the total_recharged field. +func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 11a0318f..2788aa7a 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -130,6 +130,11 @@ func BalanceNotifyEnabled(v bool) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) } +// BalanceNotifyThresholdType applies equality check predicate on the "balance_notify_threshold_type" field. It's identical to BalanceNotifyThresholdTypeEQ. +func BalanceNotifyThresholdType(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v)) +} + // BalanceNotifyThreshold applies equality check predicate on the "balance_notify_threshold" field. It's identical to BalanceNotifyThresholdEQ. func BalanceNotifyThreshold(v float64) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v)) @@ -140,6 +145,11 @@ func BalanceNotifyExtraEmails(v string) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v)) } +// TotalRecharged applies equality check predicate on the "total_recharged" field. It's identical to TotalRechargedEQ. +func TotalRecharged(v float64) predicate.User { + return predicate.User(sql.FieldEQ(FieldTotalRecharged, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -885,6 +895,71 @@ func BalanceNotifyEnabledNEQ(v bool) predicate.User { return predicate.User(sql.FieldNEQ(FieldBalanceNotifyEnabled, v)) } +// BalanceNotifyThresholdTypeEQ applies the EQ predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeNEQ applies the NEQ predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeIn applies the In predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldBalanceNotifyThresholdType, vs...)) +} + +// BalanceNotifyThresholdTypeNotIn applies the NotIn predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThresholdType, vs...)) +} + +// BalanceNotifyThresholdTypeGT applies the GT predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeGTE applies the GTE predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeLT applies the LT predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeLTE applies the LTE predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeContains applies the Contains predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeHasPrefix applies the HasPrefix predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeHasSuffix applies the HasSuffix predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeEqualFold applies the EqualFold predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeContainsFold applies the ContainsFold predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyThresholdType, v)) +} + // BalanceNotifyThresholdEQ applies the EQ predicate on the "balance_notify_threshold" field. func BalanceNotifyThresholdEQ(v float64) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v)) @@ -1000,6 +1075,46 @@ func BalanceNotifyExtraEmailsContainsFold(v string) predicate.User { return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyExtraEmails, v)) } +// TotalRechargedEQ applies the EQ predicate on the "total_recharged" field. +func TotalRechargedEQ(v float64) predicate.User { + return predicate.User(sql.FieldEQ(FieldTotalRecharged, v)) +} + +// TotalRechargedNEQ applies the NEQ predicate on the "total_recharged" field. +func TotalRechargedNEQ(v float64) predicate.User { + return predicate.User(sql.FieldNEQ(FieldTotalRecharged, v)) +} + +// TotalRechargedIn applies the In predicate on the "total_recharged" field. +func TotalRechargedIn(vs ...float64) predicate.User { + return predicate.User(sql.FieldIn(FieldTotalRecharged, vs...)) +} + +// TotalRechargedNotIn applies the NotIn predicate on the "total_recharged" field. +func TotalRechargedNotIn(vs ...float64) predicate.User { + return predicate.User(sql.FieldNotIn(FieldTotalRecharged, vs...)) +} + +// TotalRechargedGT applies the GT predicate on the "total_recharged" field. +func TotalRechargedGT(v float64) predicate.User { + return predicate.User(sql.FieldGT(FieldTotalRecharged, v)) +} + +// TotalRechargedGTE applies the GTE predicate on the "total_recharged" field. +func TotalRechargedGTE(v float64) predicate.User { + return predicate.User(sql.FieldGTE(FieldTotalRecharged, v)) +} + +// TotalRechargedLT applies the LT predicate on the "total_recharged" field. +func TotalRechargedLT(v float64) predicate.User { + return predicate.User(sql.FieldLT(FieldTotalRecharged, v)) +} + +// TotalRechargedLTE applies the LTE predicate on the "total_recharged" field. +func TotalRechargedLTE(v float64) predicate.User { + return predicate.User(sql.FieldLTE(FieldTotalRecharged, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index 955fde72..fbc64f9c 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -225,6 +225,20 @@ func (_c *UserCreate) SetNillableBalanceNotifyEnabled(v *bool) *UserCreate { return _c } +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (_c *UserCreate) SetBalanceNotifyThresholdType(v string) *UserCreate { + _c.mutation.SetBalanceNotifyThresholdType(v) + return _c +} + +// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil. +func (_c *UserCreate) SetNillableBalanceNotifyThresholdType(v *string) *UserCreate { + if v != nil { + _c.SetBalanceNotifyThresholdType(*v) + } + return _c +} + // SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. func (_c *UserCreate) SetBalanceNotifyThreshold(v float64) *UserCreate { _c.mutation.SetBalanceNotifyThreshold(v) @@ -253,6 +267,20 @@ func (_c *UserCreate) SetNillableBalanceNotifyExtraEmails(v *string) *UserCreate return _c } +// SetTotalRecharged sets the "total_recharged" field. +func (_c *UserCreate) SetTotalRecharged(v float64) *UserCreate { + _c.mutation.SetTotalRecharged(v) + return _c +} + +// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil. +func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate { + if v != nil { + _c.SetTotalRecharged(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -486,10 +514,18 @@ func (_c *UserCreate) defaults() error { v := user.DefaultBalanceNotifyEnabled _c.mutation.SetBalanceNotifyEnabled(v) } + if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok { + v := user.DefaultBalanceNotifyThresholdType + _c.mutation.SetBalanceNotifyThresholdType(v) + } if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok { v := user.DefaultBalanceNotifyExtraEmails _c.mutation.SetBalanceNotifyExtraEmails(v) } + if _, ok := _c.mutation.TotalRecharged(); !ok { + v := user.DefaultTotalRecharged + _c.mutation.SetTotalRecharged(v) + } return nil } @@ -556,9 +592,15 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)} } + if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok { + return &ValidationError{Name: "balance_notify_threshold_type", err: errors.New(`ent: missing required field "User.balance_notify_threshold_type"`)} + } if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok { return &ValidationError{Name: "balance_notify_extra_emails", err: errors.New(`ent: missing required field "User.balance_notify_extra_emails"`)} } + if _, ok := _c.mutation.TotalRecharged(); !ok { + return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)} + } return nil } @@ -646,6 +688,10 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) _node.BalanceNotifyEnabled = value } + if value, ok := _c.mutation.BalanceNotifyThresholdType(); ok { + _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value) + _node.BalanceNotifyThresholdType = value + } if value, ok := _c.mutation.BalanceNotifyThreshold(); ok { _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value) _node.BalanceNotifyThreshold = &value @@ -654,6 +700,10 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value) _node.BalanceNotifyExtraEmails = value } + if value, ok := _c.mutation.TotalRecharged(); ok { + _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value) + _node.TotalRecharged = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1068,6 +1118,18 @@ func (u *UserUpsert) UpdateBalanceNotifyEnabled() *UserUpsert { return u } +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (u *UserUpsert) SetBalanceNotifyThresholdType(v string) *UserUpsert { + u.Set(user.FieldBalanceNotifyThresholdType, v) + return u +} + +// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create. +func (u *UserUpsert) UpdateBalanceNotifyThresholdType() *UserUpsert { + u.SetExcluded(user.FieldBalanceNotifyThresholdType) + return u +} + // SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. func (u *UserUpsert) SetBalanceNotifyThreshold(v float64) *UserUpsert { u.Set(user.FieldBalanceNotifyThreshold, v) @@ -1104,6 +1166,24 @@ func (u *UserUpsert) UpdateBalanceNotifyExtraEmails() *UserUpsert { return u } +// SetTotalRecharged sets the "total_recharged" field. +func (u *UserUpsert) SetTotalRecharged(v float64) *UserUpsert { + u.Set(user.FieldTotalRecharged, v) + return u +} + +// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create. +func (u *UserUpsert) UpdateTotalRecharged() *UserUpsert { + u.SetExcluded(user.FieldTotalRecharged) + return u +} + +// AddTotalRecharged adds v to the "total_recharged" field. +func (u *UserUpsert) AddTotalRecharged(v float64) *UserUpsert { + u.Add(user.FieldTotalRecharged, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1380,6 +1460,20 @@ func (u *UserUpsertOne) UpdateBalanceNotifyEnabled() *UserUpsertOne { }) } +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (u *UserUpsertOne) SetBalanceNotifyThresholdType(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetBalanceNotifyThresholdType(v) + }) +} + +// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateBalanceNotifyThresholdType() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateBalanceNotifyThresholdType() + }) +} + // SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. func (u *UserUpsertOne) SetBalanceNotifyThreshold(v float64) *UserUpsertOne { return u.Update(func(s *UserUpsert) { @@ -1422,6 +1516,27 @@ func (u *UserUpsertOne) UpdateBalanceNotifyExtraEmails() *UserUpsertOne { }) } +// SetTotalRecharged sets the "total_recharged" field. +func (u *UserUpsertOne) SetTotalRecharged(v float64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetTotalRecharged(v) + }) +} + +// AddTotalRecharged adds v to the "total_recharged" field. +func (u *UserUpsertOne) AddTotalRecharged(v float64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddTotalRecharged(v) + }) +} + +// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateTotalRecharged() + }) +} + // Exec executes the query. func (u *UserUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1864,6 +1979,20 @@ func (u *UserUpsertBulk) UpdateBalanceNotifyEnabled() *UserUpsertBulk { }) } +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (u *UserUpsertBulk) SetBalanceNotifyThresholdType(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetBalanceNotifyThresholdType(v) + }) +} + +// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateBalanceNotifyThresholdType() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateBalanceNotifyThresholdType() + }) +} + // SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. func (u *UserUpsertBulk) SetBalanceNotifyThreshold(v float64) *UserUpsertBulk { return u.Update(func(s *UserUpsert) { @@ -1906,6 +2035,27 @@ func (u *UserUpsertBulk) UpdateBalanceNotifyExtraEmails() *UserUpsertBulk { }) } +// SetTotalRecharged sets the "total_recharged" field. +func (u *UserUpsertBulk) SetTotalRecharged(v float64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetTotalRecharged(v) + }) +} + +// AddTotalRecharged adds v to the "total_recharged" field. +func (u *UserUpsertBulk) AddTotalRecharged(v float64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddTotalRecharged(v) + }) +} + +// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateTotalRecharged() + }) +} + // Exec executes the query. func (u *UserUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index 823df0b6..6b355247 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -257,6 +257,20 @@ func (_u *UserUpdate) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdate { return _u } +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (_u *UserUpdate) SetBalanceNotifyThresholdType(v string) *UserUpdate { + _u.mutation.SetBalanceNotifyThresholdType(v) + return _u +} + +// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil. +func (_u *UserUpdate) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdate { + if v != nil { + _u.SetBalanceNotifyThresholdType(*v) + } + return _u +} + // SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. func (_u *UserUpdate) SetBalanceNotifyThreshold(v float64) *UserUpdate { _u.mutation.ResetBalanceNotifyThreshold() @@ -298,6 +312,27 @@ func (_u *UserUpdate) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdate return _u } +// SetTotalRecharged sets the "total_recharged" field. +func (_u *UserUpdate) SetTotalRecharged(v float64) *UserUpdate { + _u.mutation.ResetTotalRecharged() + _u.mutation.SetTotalRecharged(v) + return _u +} + +// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil. +func (_u *UserUpdate) SetNillableTotalRecharged(v *float64) *UserUpdate { + if v != nil { + _u.SetTotalRecharged(*v) + } + return _u +} + +// AddTotalRecharged adds value to the "total_recharged" field. +func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate { + _u.mutation.AddTotalRecharged(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -804,6 +839,9 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok { + _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value) + } if value, ok := _u.mutation.BalanceNotifyThreshold(); ok { _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value) } @@ -816,6 +854,12 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok { _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value) } + if value, ok := _u.mutation.TotalRecharged(); ok { + _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedTotalRecharged(); ok { + _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1518,6 +1562,20 @@ func (_u *UserUpdateOne) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdateOne return _u } +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (_u *UserUpdateOne) SetBalanceNotifyThresholdType(v string) *UserUpdateOne { + _u.mutation.SetBalanceNotifyThresholdType(v) + return _u +} + +// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdateOne { + if v != nil { + _u.SetBalanceNotifyThresholdType(*v) + } + return _u +} + // SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. func (_u *UserUpdateOne) SetBalanceNotifyThreshold(v float64) *UserUpdateOne { _u.mutation.ResetBalanceNotifyThreshold() @@ -1559,6 +1617,27 @@ func (_u *UserUpdateOne) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpd return _u } +// SetTotalRecharged sets the "total_recharged" field. +func (_u *UserUpdateOne) SetTotalRecharged(v float64) *UserUpdateOne { + _u.mutation.ResetTotalRecharged() + _u.mutation.SetTotalRecharged(v) + return _u +} + +// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableTotalRecharged(v *float64) *UserUpdateOne { + if v != nil { + _u.SetTotalRecharged(*v) + } + return _u +} + +// AddTotalRecharged adds value to the "total_recharged" field. +func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne { + _u.mutation.AddTotalRecharged(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -2095,6 +2174,9 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok { + _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value) + } if value, ok := _u.mutation.BalanceNotifyThreshold(); ok { _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value) } @@ -2107,6 +2189,12 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok { _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value) } + if value, ok := _u.mutation.TotalRecharged(); ok { + _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedTotalRecharged(); ok { + _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index e5e024c6..3d587a21 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -176,6 +176,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { EnableMetadataPassthrough: settings.EnableMetadataPassthrough, EnableCCHSigning: settings.EnableCCHSigning, WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, + BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, + BalanceLowNotifyThresholdType: settings.BalanceLowNotifyThresholdType, + BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, + AccountQuotaNotifyEmails: settings.AccountQuotaNotifyEmails, PaymentEnabled: paymentCfg.Enabled, PaymentMinAmount: paymentCfg.MinAmount, PaymentMaxAmount: paymentCfg.MaxAmount, @@ -305,6 +309,12 @@ type UpdateSettingsRequest struct { EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` EnableCCHSigning *bool `json:"enable_cch_signing"` + // Balance low notification + BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` + BalanceLowNotifyThresholdType *string `json:"balance_low_notify_threshold_type"` + BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` + AccountQuotaNotifyEmails *[]string `json:"account_quota_notify_emails"` + // Payment configuration (integrated into settings, full replace) PaymentEnabled *bool `json:"payment_enabled"` PaymentMinAmount *float64 `json:"payment_min_amount"` @@ -882,6 +892,30 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.EnableCCHSigning }(), + BalanceLowNotifyEnabled: func() bool { + if req.BalanceLowNotifyEnabled != nil { + return *req.BalanceLowNotifyEnabled + } + return previousSettings.BalanceLowNotifyEnabled + }(), + BalanceLowNotifyThresholdType: func() string { + if req.BalanceLowNotifyThresholdType != nil { + return *req.BalanceLowNotifyThresholdType + } + return previousSettings.BalanceLowNotifyThresholdType + }(), + BalanceLowNotifyThreshold: func() float64 { + if req.BalanceLowNotifyThreshold != nil { + return *req.BalanceLowNotifyThreshold + } + return previousSettings.BalanceLowNotifyThreshold + }(), + AccountQuotaNotifyEmails: func() []string { + if req.AccountQuotaNotifyEmails != nil { + return *req.AccountQuotaNotifyEmails + } + return previousSettings.AccountQuotaNotifyEmails + }(), } if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { @@ -1028,6 +1062,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, EnableCCHSigning: updatedSettings.EnableCCHSigning, + BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, + BalanceLowNotifyThresholdType: updatedSettings.BalanceLowNotifyThresholdType, + BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold, + AccountQuotaNotifyEmails: updatedSettings.AccountQuotaNotifyEmails, PaymentEnabled: updatedPaymentCfg.Enabled, PaymentMinAmount: updatedPaymentCfg.MinAmount, PaymentMaxAmount: updatedPaymentCfg.MaxAmount, diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index a465c7fb..147072c3 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -13,19 +13,21 @@ func UserFromServiceShallow(u *service.User) *User { return nil } return &User{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - Role: u.Role, - Balance: u.Balance, - Concurrency: u.Concurrency, - Status: u.Status, - AllowedGroups: u.AllowedGroups, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, - BalanceNotifyEnabled: u.BalanceNotifyEnabled, - BalanceNotifyThreshold: u.BalanceNotifyThreshold, - BalanceNotifyExtraEmails: u.BalanceNotifyExtraEmails, + ID: u.ID, + Email: u.Email, + Username: u.Username, + Role: u.Role, + Balance: u.Balance, + Concurrency: u.Concurrency, + Status: u.Status, + AllowedGroups: u.AllowedGroups, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, + BalanceNotifyEnabled: u.BalanceNotifyEnabled, + BalanceNotifyThresholdType: u.BalanceNotifyThresholdType, + BalanceNotifyThreshold: u.BalanceNotifyThreshold, + BalanceNotifyExtraEmails: u.BalanceNotifyExtraEmails, + TotalRecharged: u.TotalRecharged, } } diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index e29f72da..8da7c6f2 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -150,9 +150,10 @@ type SystemSettings struct { PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"` // Balance low notification - BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` - BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` - AccountQuotaNotifyEmails []string `json:"account_quota_notify_emails"` + BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` + BalanceLowNotifyThresholdType string `json:"balance_low_notify_threshold_type"` + BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` + AccountQuotaNotifyEmails []string `json:"account_quota_notify_emails"` } type DefaultSubscriptionSetting struct { diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 18522868..425d3df9 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -19,9 +19,11 @@ type User struct { UpdatedAt time.Time `json:"updated_at"` // 余额不足通知 - BalanceNotifyEnabled bool `json:"balance_notify_enabled"` - BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` - BalanceNotifyExtraEmails []string `json:"balance_notify_extra_emails"` + BalanceNotifyEnabled bool `json:"balance_notify_enabled"` + BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"` + BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` + BalanceNotifyExtraEmails []string `json:"balance_notify_extra_emails"` + TotalRecharged float64 `json:"total_recharged"` APIKeys []APIKey `json:"api_keys,omitempty"` Subscriptions []UserSubscription `json:"subscriptions,omitempty"` diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 4fb72ce7..48528d55 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -33,9 +33,10 @@ type ChangePasswordRequest struct { // UpdateProfileRequest represents the update profile request payload type UpdateProfileRequest struct { - Username *string `json:"username"` - BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` - BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` + Username *string `json:"username"` + BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` + BalanceNotifyThresholdType *string `json:"balance_notify_threshold_type"` + BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } // GetProfile handles getting user profile @@ -100,9 +101,10 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { } svcReq := service.UpdateProfileRequest{ - Username: req.Username, - BalanceNotifyEnabled: req.BalanceNotifyEnabled, - BalanceNotifyThreshold: req.BalanceNotifyThreshold, + Username: req.Username, + BalanceNotifyEnabled: req.BalanceNotifyEnabled, + BalanceNotifyThresholdType: req.BalanceNotifyThresholdType, + BalanceNotifyThreshold: req.BalanceNotifyThreshold, } updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq) if err != nil { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 752a5937..4ecab47a 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -641,22 +641,24 @@ func userEntityToService(u *dbent.User) *service.User { return nil } out := &service.User{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - Notes: u.Notes, - PasswordHash: u.PasswordHash, - Role: u.Role, - Balance: u.Balance, - Concurrency: u.Concurrency, - Status: u.Status, - TotpSecretEncrypted: u.TotpSecretEncrypted, - TotpEnabled: u.TotpEnabled, - TotpEnabledAt: u.TotpEnabledAt, - BalanceNotifyEnabled: u.BalanceNotifyEnabled, - BalanceNotifyThreshold: u.BalanceNotifyThreshold, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, + ID: u.ID, + Email: u.Email, + Username: u.Username, + Notes: u.Notes, + PasswordHash: u.PasswordHash, + Role: u.Role, + Balance: u.Balance, + Concurrency: u.Concurrency, + Status: u.Status, + TotpSecretEncrypted: u.TotpSecretEncrypted, + TotpEnabled: u.TotpEnabled, + TotpEnabledAt: u.TotpEnabledAt, + BalanceNotifyEnabled: u.BalanceNotifyEnabled, + BalanceNotifyThresholdType: u.BalanceNotifyThresholdType, + BalanceNotifyThreshold: u.BalanceNotifyThreshold, + TotalRecharged: u.TotalRecharged, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, } // Parse extra emails JSON array if u.BalanceNotifyExtraEmails != "" && u.BalanceNotifyExtraEmails != "[]" { diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 2c544857..63168fb1 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -148,6 +148,7 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled). + SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType). SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold). SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)) if userIn.BalanceNotifyThreshold == nil { @@ -389,7 +390,12 @@ func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[ func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { client := clientFromContext(ctx, r.client) - n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx) + update := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount) + // Track cumulative recharge amount for percentage-based notifications + if amount > 0 { + update = update.AddTotalRecharged(amount) + } + n, err := update.Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, nil) } diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 8dd56b8f..7fbdd254 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -47,30 +47,21 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u if user == nil || s.emailService == nil || s.settingRepo == nil { return } - - // Check user-level switch if !user.BalanceNotifyEnabled { return } - // Check global switch - globalEnabled, threshold := s.getBalanceNotifyConfig(ctx) + globalEnabled, globalThresholdType, globalThresholdValue := s.getBalanceNotifyConfig(ctx) if !globalEnabled { return } - // User custom threshold overrides system default - if user.BalanceNotifyThreshold != nil { - threshold = *user.BalanceNotifyThreshold - } - + threshold := s.resolveEffectiveThreshold(user, globalThresholdType, globalThresholdValue) if threshold <= 0 { return } newBalance := oldBalance - cost - - // Only notify on first crossing if oldBalance >= threshold && newBalance < threshold { siteName := s.getSiteName(ctx) recipients := s.collectBalanceNotifyRecipients(user) @@ -85,6 +76,30 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u } } +// resolveEffectiveThreshold computes the actual USD threshold based on type and user settings. +func (s *BalanceNotifyService) resolveEffectiveThreshold(user *User, globalType string, globalValue float64) float64 { + // User-level override takes full precedence + if user.BalanceNotifyThreshold != nil { + thresholdType := user.BalanceNotifyThresholdType + if thresholdType == "" { + thresholdType = globalType + } + return computeThreshold(thresholdType, *user.BalanceNotifyThreshold, user.TotalRecharged) + } + return computeThreshold(globalType, globalValue, user.TotalRecharged) +} + +// computeThreshold converts a threshold value to USD based on type. +func computeThreshold(thresholdType string, value, totalRecharged float64) float64 { + if thresholdType == ThresholdTypePercentage { + if totalRecharged <= 0 { + return 0 // no recharge history → skip percentage check + } + return totalRecharged * value / 100 + } + return value // fixed USD amount +} + // quotaDim describes one quota dimension for notification checking. type quotaDim struct { name string @@ -139,13 +154,21 @@ func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, account } // getBalanceNotifyConfig reads global balance notification settings. -func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, threshold float64) { - keys := []string{SettingKeyBalanceLowNotifyEnabled, SettingKeyBalanceLowNotifyThreshold} +func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, thresholdType string, threshold float64) { + keys := []string{ + SettingKeyBalanceLowNotifyEnabled, + SettingKeyBalanceLowNotifyThresholdType, + SettingKeyBalanceLowNotifyThreshold, + } settings, err := s.settingRepo.GetMultiple(ctx, keys) if err != nil { - return false, 0 + return false, ThresholdTypeFixed, 0 } enabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true" + thresholdType = settings[SettingKeyBalanceLowNotifyThresholdType] + if thresholdType == "" { + thresholdType = ThresholdTypeFixed + } if v := settings[SettingKeyBalanceLowNotifyThreshold]; v != "" { if f, err := strconv.ParseFloat(v, 64); err == nil { threshold = f diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 2704e0d0..3de0e343 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -251,8 +251,13 @@ const ( SettingKeyEnableCCHSigning = "enable_cch_signing" // Balance Low Notification - SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关 - SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD) + SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关 + SettingKeyBalanceLowNotifyThresholdType = "balance_low_notify_threshold_type" // "fixed" | "percentage" + SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD 或百分比) + + // Threshold type constants + ThresholdTypeFixed = "fixed" + ThresholdTypePercentage = "percentage" // Account Quota Notification SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index bc4f53ce..e2491cbc 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -597,6 +597,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet // Balance low notification updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled) + thresholdType := settings.BalanceLowNotifyThresholdType + if thresholdType == "" { + thresholdType = ThresholdTypeFixed + } + updates[SettingKeyBalanceLowNotifyThresholdType] = thresholdType updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64) accountQuotaNotifyEmailsJSON, err := json.Marshal(settings.AccountQuotaNotifyEmails) if err != nil { @@ -1228,6 +1233,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin // Balance low notification result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true" + result.BalanceLowNotifyThresholdType = settings[SettingKeyBalanceLowNotifyThresholdType] + if result.BalanceLowNotifyThresholdType == "" { + result.BalanceLowNotifyThresholdType = ThresholdTypeFixed + } if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 { result.BalanceLowNotifyThreshold = v } diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index debc2b19..b28d2247 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -108,8 +108,9 @@ type SystemSettings struct { EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) // Balance low notification - BalanceLowNotifyEnabled bool - BalanceLowNotifyThreshold float64 + BalanceLowNotifyEnabled bool + BalanceLowNotifyThresholdType string // "fixed" (default) | "percentage" + BalanceLowNotifyThreshold float64 // Account quota notification AccountQuotaNotifyEmails []string diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index b4818223..4ca31adc 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -31,9 +31,11 @@ type User struct { TotpEnabledAt *time.Time // TOTP 启用时间 // 余额不足通知 - BalanceNotifyEnabled bool - BalanceNotifyThreshold *float64 - BalanceNotifyExtraEmails []string + BalanceNotifyEnabled bool + BalanceNotifyThresholdType string // "fixed" (default) | "percentage" + BalanceNotifyThreshold *float64 + BalanceNotifyExtraEmails []string + TotalRecharged float64 APIKeys []APIKey Subscriptions []UserSubscription diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index e6b9a210..4669cb2b 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -62,11 +62,12 @@ type UserRepository interface { // UpdateProfileRequest 更新用户资料请求 type UpdateProfileRequest struct { - Email *string `json:"email"` - Username *string `json:"username"` - Concurrency *int `json:"concurrency"` - BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` - BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` + Email *string `json:"email"` + Username *string `json:"username"` + Concurrency *int `json:"concurrency"` + BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` + BalanceNotifyThresholdType *string `json:"balance_notify_threshold_type"` + BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } // ChangePasswordRequest 修改密码请求 @@ -143,6 +144,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat if req.BalanceNotifyEnabled != nil { user.BalanceNotifyEnabled = *req.BalanceNotifyEnabled } + if req.BalanceNotifyThresholdType != nil { + user.BalanceNotifyThresholdType = *req.BalanceNotifyThresholdType + } if req.BalanceNotifyThreshold != nil { if *req.BalanceNotifyThreshold <= 0 { user.BalanceNotifyThreshold = nil // clear to system default diff --git a/backend/migrations/102_add_balance_notify_threshold_type.sql b/backend/migrations/102_add_balance_notify_threshold_type.sql new file mode 100644 index 00000000..7ad70552 --- /dev/null +++ b/backend/migrations/102_add_balance_notify_threshold_type.sql @@ -0,0 +1,4 @@ +-- Add threshold type support (fixed / percentage) to balance notification +ALTER TABLE users ADD COLUMN IF NOT EXISTS balance_notify_threshold_type VARCHAR(10) NOT NULL DEFAULT 'fixed'; +-- Track cumulative recharge amount for percentage threshold calculation +ALTER TABLE users ADD COLUMN IF NOT EXISTS total_recharged DECIMAL(20,8) NOT NULL DEFAULT 0; diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 31284289..ec290be5 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -137,6 +137,7 @@ export interface SystemSettings { // Balance & quota notification balance_low_notify_enabled: boolean + balance_low_notify_threshold_type: 'fixed' | 'percentage' balance_low_notify_threshold: number account_quota_notify_emails: string[] } @@ -240,6 +241,7 @@ export interface UpdateSettingsRequest { payment_cancel_rate_limit_window_mode?: string // Balance & quota notification balance_low_notify_enabled?: boolean + balance_low_notify_threshold_type?: 'fixed' | 'percentage' balance_low_notify_threshold?: number account_quota_notify_emails?: string[] } diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 8e10bf2a..880a81ee 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -4633,8 +4633,12 @@ export default { title: 'Balance Low Notification', description: 'Send email notification when user balance falls below threshold', enabled: 'Enable Balance Low Notification', + thresholdType: 'Threshold Type', + typeFixed: 'Fixed Amount', + typePercentage: 'Percentage of Recharged', threshold: 'Default Threshold', thresholdHint: 'Used when user has not set a custom value', + percentageHint: 'Notify when balance falls below this percentage of total recharged amount', thresholdPlaceholder: 'Enter amount', }, quotaNotify: { diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 1b82f419..41d94e06 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -4797,8 +4797,12 @@ export default { title: '余额不足提醒', description: '当用户余额低于阈值时发送邮件提醒', enabled: '启用余额不足提醒', - threshold: '默认提醒阈值', + thresholdType: '阈值类型', + typeFixed: '固定金额', + typePercentage: '充值百分比', + threshold: '提醒阈值', thresholdHint: '用户未自定义时使用此值', + percentageHint: '当余额低于累计充值额的此百分比时提醒', thresholdPlaceholder: '输入金额', }, quotaNotify: { diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 8ed77203..af84b67d 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -2660,6 +2660,90 @@ + +
+
+

+ {{ t('admin.settings.balanceNotify.title') }} +

+

+ {{ t('admin.settings.balanceNotify.description') }} +

+
+
+
+ + +
+
+ +
+ +
+ + +
+
+ +
+ +
+ + {{ form.balance_low_notify_threshold_type === 'percentage' ? '%' : '$' }} + + +
+

+ {{ form.balance_low_notify_threshold_type === 'percentage' + ? t('admin.settings.balanceNotify.percentageHint') + : t('admin.settings.balanceNotify.thresholdHint') }} +

+
+
+
+
+ + +
+
+

+ {{ t('admin.settings.quotaNotify.title') }} +

+

+ {{ t('admin.settings.quotaNotify.description') }} +

+
+
+
+ +
+
+ + +
+ +
+

{{ t('admin.settings.quotaNotify.emailsHint') }}

+
+
+
@@ -2939,7 +3023,12 @@ const form = reactive({ // Gateway forwarding behavior enable_fingerprint_unification: true, enable_metadata_passthrough: false, - enable_cch_signing: false + enable_cch_signing: false, + // Balance & quota notification + balance_low_notify_enabled: false, + balance_low_notify_threshold_type: 'fixed' as 'fixed' | 'percentage', + balance_low_notify_threshold: 0, + account_quota_notify_emails: [] as string[] }) // Proxies for web search emulation ProxySelector @@ -3149,6 +3238,14 @@ function handleRegistrationEmailSuffixWhitelistPaste(event: ClipboardEvent) { } } +// Quota notify email helpers +const addQuotaNotifyEmail = () => { + if (!form.account_quota_notify_emails) { + form.account_quota_notify_emails = [] + } + form.account_quota_notify_emails.push('') +} + // LinuxDo OAuth redirect URL suggestion const linuxdoRedirectUrlSuggestion = computed(() => { if (typeof window === 'undefined') return '' @@ -3488,6 +3585,11 @@ async function saveSettings() { payment_cancel_rate_limit_window: Number(form.payment_cancel_rate_limit_window) || 1, payment_cancel_rate_limit_unit: form.payment_cancel_rate_limit_unit, payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode, + // Balance & quota notification + balance_low_notify_enabled: form.balance_low_notify_enabled, + balance_low_notify_threshold_type: form.balance_low_notify_threshold_type, + balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0, + account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e: string) => e.trim() !== ''), } const updated = await adminAPI.settings.updateSettings(payload) From 9e33d0c4c079062f1d88fa7a2accd0bb2cb88b37 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 12 Apr 2026 14:43:12 +0800 Subject: [PATCH 25/88] fix: address audit findings for websearch and balance notification - Fix GetByKeyForAuth not selecting balance notify fields (notifications never triggered in gateway path) - Fix provider-level ProxyURL never resolved: inject ProxyRepository into SettingService, resolve proxy URLs when building Manager - Fix admin manual balance adjustment not updating total_recharged - Add threshold_type input validation (reject invalid values) - Fix user threshold_type inheritance: custom threshold defaults to "fixed" instead of inheriting global type (prevents $5 being treated as 5%) - Add try-catch for clipboard.writeText (fails on non-HTTPS) - Add SetTotalRecharged to user Update for admin balance operations --- backend/cmd/server/wire_gen.go | 4 +- backend/internal/repository/api_key_repo.go | 5 ++ backend/internal/repository/user_repo.go | 3 +- backend/internal/server/http.go | 5 +- backend/internal/service/admin_service.go | 6 ++ .../service/balance_notify_service.go | 4 +- backend/internal/service/setting_service.go | 31 ++++++-- backend/internal/service/user_service.go | 4 +- backend/internal/service/websearch_config.go | 70 ++++++++++++------- backend/internal/service/wire.go | 5 +- frontend/src/views/admin/SettingsView.vue | 8 ++- 11 files changed, 102 insertions(+), 43 deletions(-) diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 8c47b2bd..24ee02fd 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -50,7 +50,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { refreshTokenCache := repository.NewRefreshTokenCache(redisClient) settingRepository := repository.NewSettingRepository(client) groupRepository := repository.NewGroupRepository(client, db) - settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig) + proxyRepository := repository.NewProxyRepository(client, db) + settingService := service.ProvideSettingService(settingRepository, groupRepository, proxyRepository, configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) turnstileVerifier := repository.NewTurnstileVerifier() @@ -100,7 +101,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) schedulerCache := repository.NewSchedulerCache(redisClient) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) - proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) privacyClientFactory := providePrivacyClientFactory() diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 4ecab47a..11eac7a8 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -143,6 +143,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se user.FieldRole, user.FieldBalance, user.FieldConcurrency, + user.FieldBalanceNotifyEnabled, + user.FieldBalanceNotifyThresholdType, + user.FieldBalanceNotifyThreshold, + user.FieldBalanceNotifyExtraEmails, + user.FieldTotalRecharged, ) }). WithGroup(func(q *dbent.GroupQuery) { diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 63168fb1..1792ef8d 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -150,7 +150,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled). SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType). SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold). - SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)) + SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)). + SetTotalRecharged(userIn.TotalRecharged) if userIn.BalanceNotifyThreshold == nil { updateOp = updateOp.ClearBalanceNotifyThreshold() } diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index ba45c31b..5165b059 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -59,7 +59,7 @@ func ProvideRouter( } // Wire up websearch Manager builder so it initializes on startup and rebuilds on config save. - settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig) { + settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig, proxyURLs map[int64]string) { if cfg == nil || !cfg.Enabled || len(cfg.Providers) == 0 { service.SetWebSearchManager(nil) return @@ -80,6 +80,9 @@ func ProvideRouter( } if p.ProxyID != nil { pc.ProxyID = *p.ProxyID + if u, ok := proxyURLs[*p.ProxyID]; ok { + pc.ProxyURL = u + } } configs = append(configs, pc) } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 97b42c24..a4e22b22 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -709,6 +709,12 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance) } + // Track cumulative recharge for percentage-based balance notifications + balanceDelta := user.Balance - oldBalance + if balanceDelta > 0 { + user.TotalRecharged += balanceDelta + } + if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 7fbdd254..e1f6bd8b 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -77,12 +77,12 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u } // resolveEffectiveThreshold computes the actual USD threshold based on type and user settings. +// When user sets a custom threshold, their type is used independently (defaults to "fixed" if unset). func (s *BalanceNotifyService) resolveEffectiveThreshold(user *User, globalType string, globalValue float64) float64 { - // User-level override takes full precedence if user.BalanceNotifyThreshold != nil { thresholdType := user.BalanceNotifyThresholdType if thresholdType == "" { - thresholdType = globalType + thresholdType = ThresholdTypeFixed // user custom value defaults to fixed, not inherited } return computeThreshold(thresholdType, *user.BalanceNotifyThreshold, user.TotalRecharged) } diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index e2491cbc..9b307426 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -99,13 +99,19 @@ type DefaultSubscriptionGroupReader interface { GetByID(ctx context.Context, id int64) (*Group, error) } +// WebSearchManagerBuilder creates a websearch.Manager from config (injected by infra layer). +// proxyURLs maps proxy ID to resolved URL for provider-level proxy support. +type WebSearchManagerBuilder func(cfg *WebSearchEmulationConfig, proxyURLs map[int64]string) + // SettingService 系统设置服务 type SettingService struct { - settingRepo SettingRepository - defaultSubGroupReader DefaultSubscriptionGroupReader - cfg *config.Config - onUpdate func() // Callback when settings are updated (for cache invalidation) - version string // Application version + settingRepo SettingRepository + defaultSubGroupReader DefaultSubscriptionGroupReader + proxyRepo ProxyRepository // for resolving websearch provider proxy URLs + cfg *config.Config + onUpdate func() // Callback when settings are updated (for cache invalidation) + version string // Application version + webSearchManagerBuilder WebSearchManagerBuilder } // NewSettingService 创建系统设置服务实例 @@ -121,6 +127,11 @@ func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscri s.defaultSubGroupReader = reader } +// SetProxyRepository injects a proxy repo for resolving websearch provider proxy URLs. +func (s *SettingService) SetProxyRepository(repo ProxyRepository) { + s.proxyRepo = repo +} + // GetAllSettings 获取所有系统设置 func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) { settings, err := s.settingRepo.GetAll(ctx) @@ -598,7 +609,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet // Balance low notification updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled) thresholdType := settings.BalanceLowNotifyThresholdType - if thresholdType == "" { + if thresholdType != ThresholdTypeFixed && thresholdType != ThresholdTypePercentage { thresholdType = ThresholdTypeFixed } updates[SettingKeyBalanceLowNotifyThresholdType] = thresholdType @@ -1231,6 +1242,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true" result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true" + // Web search emulation: quick enabled check from the JSON config + if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" { + var wsCfg WebSearchEmulationConfig + if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil { + result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0 + } + } + // Balance low notification result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true" result.BalanceLowNotifyThresholdType = settings[SettingKeyBalanceLowNotifyThresholdType] diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 4669cb2b..26021a9b 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -145,7 +145,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat user.BalanceNotifyEnabled = *req.BalanceNotifyEnabled } if req.BalanceNotifyThresholdType != nil { - user.BalanceNotifyThresholdType = *req.BalanceNotifyThresholdType + if *req.BalanceNotifyThresholdType == ThresholdTypeFixed || *req.BalanceNotifyThresholdType == ThresholdTypePercentage { + user.BalanceNotifyThresholdType = *req.BalanceNotifyThresholdType + } } if req.BalanceNotifyThreshold != nil { if *req.BalanceNotifyThreshold <= 0 { diff --git a/backend/internal/service/websearch_config.go b/backend/internal/service/websearch_config.go index bdfff3e4..346faf1f 100644 --- a/backend/internal/service/websearch_config.go +++ b/backend/internal/service/websearch_config.go @@ -10,7 +10,6 @@ import ( infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/websearch" - "github.com/redis/go-redis/v9" "golang.org/x/sync/singleflight" ) @@ -85,8 +84,7 @@ const ( // GetWebSearchEmulationConfig returns the configuration with in-process cache + singleflight. func (s *SettingService) GetWebSearchEmulationConfig(ctx context.Context) (*WebSearchEmulationConfig, error) { if cached := webSearchEmulationCache.Load(); cached != nil { - c := cached.(*cachedWebSearchEmulationConfig) - if time.Now().UnixNano() < c.expiresAt { + if c, ok := cached.(*cachedWebSearchEmulationConfig); ok && time.Now().UnixNano() < c.expiresAt { return c.config, nil } } @@ -96,7 +94,10 @@ func (s *SettingService) GetWebSearchEmulationConfig(ctx context.Context) (*WebS if err != nil { return &WebSearchEmulationConfig{}, err } - return result.(*WebSearchEmulationConfig), nil + if cfg, ok := result.(*WebSearchEmulationConfig); ok { + return cfg, nil + } + return &WebSearchEmulationConfig{}, nil } func (s *SettingService) loadWebSearchConfigFromDB() (*WebSearchEmulationConfig, error) { @@ -154,7 +155,7 @@ func (s *SettingService) SaveWebSearchEmulationConfig(ctx context.Context, cfg * }) // Hot-reload: rebuild the global Manager with new config - s.RebuildWebSearchManager(ctx) + s.rebuildWebSearchManager(ctx) return nil } @@ -196,34 +197,51 @@ func (s *SettingService) IsWebSearchEmulationEnabled(ctx context.Context) bool { return cfg.Enabled && len(cfg.Providers) > 0 } -// SetWebSearchRedisClient injects the Redis client used for quota tracking. -// Call after construction, before first use. Triggers initial Manager build. -func (s *SettingService) SetWebSearchRedisClient(ctx context.Context, redisClient *redis.Client) { - s.webSearchRedis = redisClient - s.RebuildWebSearchManager(ctx) +// SetWebSearchManagerBuilder injects a callback that creates and wires a websearch.Manager. +// The infra layer (main/wire) provides this builder, keeping redis out of the service layer. +// Triggers initial build. +func (s *SettingService) SetWebSearchManagerBuilder(ctx context.Context, builder WebSearchManagerBuilder) { + s.webSearchManagerBuilder = builder + s.rebuildWebSearchManager(ctx) } -// RebuildWebSearchManager reads the current config and (re)creates the global websearch.Manager. -// Called on startup and after SaveWebSearchEmulationConfig. -func (s *SettingService) RebuildWebSearchManager(ctx context.Context) { +// rebuildWebSearchManager reads the current config, resolves proxy URLs, and invokes the builder. +func (s *SettingService) rebuildWebSearchManager(ctx context.Context) { + if s.webSearchManagerBuilder == nil { + return + } cfg, err := s.GetWebSearchEmulationConfig(ctx) - if err != nil || !cfg.Enabled || len(cfg.Providers) == 0 { + if err != nil { SetWebSearchManager(nil) return } - providerConfigs := make([]websearch.ProviderConfig, 0, len(cfg.Providers)) - for _, p := range cfg.Providers { - providerConfigs = append(providerConfigs, websearch.ProviderConfig{ - Type: p.Type, - APIKey: p.APIKey, - Priority: p.Priority, - QuotaLimit: p.QuotaLimit, - QuotaRefreshInterval: p.QuotaRefreshInterval, - ExpiresAt: p.ExpiresAt, - }) + proxyURLs := s.resolveProviderProxyURLs(ctx, cfg) + s.webSearchManagerBuilder(cfg, proxyURLs) +} + +// resolveProviderProxyURLs collects proxy IDs from providers and resolves them to URLs. +func (s *SettingService) resolveProviderProxyURLs(ctx context.Context, cfg *WebSearchEmulationConfig) map[int64]string { + if cfg == nil || s.proxyRepo == nil { + return nil } - SetWebSearchManager(websearch.NewManager(providerConfigs, s.webSearchRedis)) - slog.Info("websearch: manager rebuilt", "provider_count", len(providerConfigs)) + var ids []int64 + for _, p := range cfg.Providers { + if p.ProxyID != nil && *p.ProxyID > 0 { + ids = append(ids, *p.ProxyID) + } + } + if len(ids) == 0 { + return nil + } + proxies, err := s.proxyRepo.ListByIDs(ctx, ids) + if err != nil { + return nil + } + result := make(map[int64]string, len(proxies)) + for _, px := range proxies { + result[px.ID] = px.URL() + } + return result } // WebSearchTestResult holds the result of a search test. diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 2827f135..b4e33039 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -373,10 +373,11 @@ func ProvideBackupService( return svc } -// ProvideSettingService wires SettingService with group reader for default subscription validation. -func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService { +// ProvideSettingService wires SettingService with group reader and proxy repo. +func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, proxyRepo ProxyRepository, cfg *config.Config) *SettingService { svc := NewSettingService(settingRepo, cfg) svc.SetDefaultSubscriptionGroupReader(groupRepo) + svc.SetProxyRepository(proxyRepo) return svc } diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index af84b67d..50b532fb 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -3109,8 +3109,12 @@ async function copyApiKey(idx: number) { appStore.showError(t('admin.settings.webSearchEmulation.apiKeyPlaceholder')) return } - await navigator.clipboard.writeText(key) - appStore.showSuccess(t('admin.settings.webSearchEmulation.copied')) + try { + await navigator.clipboard.writeText(key) + appStore.showSuccess(t('admin.settings.webSearchEmulation.copied')) + } catch { + appStore.showError(t('common.error')) + } } async function testWebSearchProvider() { From cef22c70abb113eca5ff01262a5e5f31a63ce0fe Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 12 Apr 2026 15:01:10 +0800 Subject: [PATCH 26/88] fix(notify): remove percentage threshold from balance notification Balance low notification only supports fixed USD amount threshold. Percentage threshold is a quota concept, not applicable to balance. Reverted threshold_type from admin settings, user profile, and all backend/frontend layers. DB fields (balance_notify_threshold_type, total_recharged) retained for potential future quota use. --- .../internal/handler/admin/setting_handler.go | 15 ++---- backend/internal/handler/dto/settings.go | 7 ++- backend/internal/handler/user_handler.go | 14 +++--- .../service/balance_notify_service.go | 46 ++++--------------- backend/internal/service/domain_constants.go | 9 +--- backend/internal/service/setting_service.go | 9 ---- backend/internal/service/settings_view.go | 5 +- backend/internal/service/user_service.go | 16 ++----- frontend/src/api/admin/settings.ts | 2 - frontend/src/i18n/locales/en.ts | 4 -- frontend/src/i18n/locales/zh.ts | 6 +-- frontend/src/views/admin/SettingsView.vue | 44 +++--------------- 12 files changed, 37 insertions(+), 140 deletions(-) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 3d587a21..49e7aeed 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -177,7 +177,6 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { EnableCCHSigning: settings.EnableCCHSigning, WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, - BalanceLowNotifyThresholdType: settings.BalanceLowNotifyThresholdType, BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, AccountQuotaNotifyEmails: settings.AccountQuotaNotifyEmails, PaymentEnabled: paymentCfg.Enabled, @@ -310,10 +309,9 @@ type UpdateSettingsRequest struct { EnableCCHSigning *bool `json:"enable_cch_signing"` // Balance low notification - BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` - BalanceLowNotifyThresholdType *string `json:"balance_low_notify_threshold_type"` - BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` - AccountQuotaNotifyEmails *[]string `json:"account_quota_notify_emails"` + BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` + BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` + AccountQuotaNotifyEmails *[]string `json:"account_quota_notify_emails"` // Payment configuration (integrated into settings, full replace) PaymentEnabled *bool `json:"payment_enabled"` @@ -898,12 +896,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.BalanceLowNotifyEnabled }(), - BalanceLowNotifyThresholdType: func() string { - if req.BalanceLowNotifyThresholdType != nil { - return *req.BalanceLowNotifyThresholdType - } - return previousSettings.BalanceLowNotifyThresholdType - }(), BalanceLowNotifyThreshold: func() float64 { if req.BalanceLowNotifyThreshold != nil { return *req.BalanceLowNotifyThreshold @@ -1063,7 +1055,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, EnableCCHSigning: updatedSettings.EnableCCHSigning, BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, - BalanceLowNotifyThresholdType: updatedSettings.BalanceLowNotifyThresholdType, BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold, AccountQuotaNotifyEmails: updatedSettings.AccountQuotaNotifyEmails, PaymentEnabled: updatedPaymentCfg.Enabled, diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 8da7c6f2..e29f72da 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -150,10 +150,9 @@ type SystemSettings struct { PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"` // Balance low notification - BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` - BalanceLowNotifyThresholdType string `json:"balance_low_notify_threshold_type"` - BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` - AccountQuotaNotifyEmails []string `json:"account_quota_notify_emails"` + BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` + BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` + AccountQuotaNotifyEmails []string `json:"account_quota_notify_emails"` } type DefaultSubscriptionSetting struct { diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 48528d55..4fb72ce7 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -33,10 +33,9 @@ type ChangePasswordRequest struct { // UpdateProfileRequest represents the update profile request payload type UpdateProfileRequest struct { - Username *string `json:"username"` - BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` - BalanceNotifyThresholdType *string `json:"balance_notify_threshold_type"` - BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` + Username *string `json:"username"` + BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` + BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } // GetProfile handles getting user profile @@ -101,10 +100,9 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { } svcReq := service.UpdateProfileRequest{ - Username: req.Username, - BalanceNotifyEnabled: req.BalanceNotifyEnabled, - BalanceNotifyThresholdType: req.BalanceNotifyThresholdType, - BalanceNotifyThreshold: req.BalanceNotifyThreshold, + Username: req.Username, + BalanceNotifyEnabled: req.BalanceNotifyEnabled, + BalanceNotifyThreshold: req.BalanceNotifyThreshold, } updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq) if err != nil { diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index e1f6bd8b..0d7e4c09 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -51,12 +51,16 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u return } - globalEnabled, globalThresholdType, globalThresholdValue := s.getBalanceNotifyConfig(ctx) + globalEnabled, globalThreshold := s.getBalanceNotifyConfig(ctx) if !globalEnabled { return } - threshold := s.resolveEffectiveThreshold(user, globalThresholdType, globalThresholdValue) + // User custom threshold overrides system default + threshold := globalThreshold + if user.BalanceNotifyThreshold != nil { + threshold = *user.BalanceNotifyThreshold + } if threshold <= 0 { return } @@ -76,30 +80,6 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u } } -// resolveEffectiveThreshold computes the actual USD threshold based on type and user settings. -// When user sets a custom threshold, their type is used independently (defaults to "fixed" if unset). -func (s *BalanceNotifyService) resolveEffectiveThreshold(user *User, globalType string, globalValue float64) float64 { - if user.BalanceNotifyThreshold != nil { - thresholdType := user.BalanceNotifyThresholdType - if thresholdType == "" { - thresholdType = ThresholdTypeFixed // user custom value defaults to fixed, not inherited - } - return computeThreshold(thresholdType, *user.BalanceNotifyThreshold, user.TotalRecharged) - } - return computeThreshold(globalType, globalValue, user.TotalRecharged) -} - -// computeThreshold converts a threshold value to USD based on type. -func computeThreshold(thresholdType string, value, totalRecharged float64) float64 { - if thresholdType == ThresholdTypePercentage { - if totalRecharged <= 0 { - return 0 // no recharge history → skip percentage check - } - return totalRecharged * value / 100 - } - return value // fixed USD amount -} - // quotaDim describes one quota dimension for notification checking. type quotaDim struct { name string @@ -154,21 +134,13 @@ func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, account } // getBalanceNotifyConfig reads global balance notification settings. -func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, thresholdType string, threshold float64) { - keys := []string{ - SettingKeyBalanceLowNotifyEnabled, - SettingKeyBalanceLowNotifyThresholdType, - SettingKeyBalanceLowNotifyThreshold, - } +func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, threshold float64) { + keys := []string{SettingKeyBalanceLowNotifyEnabled, SettingKeyBalanceLowNotifyThreshold} settings, err := s.settingRepo.GetMultiple(ctx, keys) if err != nil { - return false, ThresholdTypeFixed, 0 + return false, 0 } enabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true" - thresholdType = settings[SettingKeyBalanceLowNotifyThresholdType] - if thresholdType == "" { - thresholdType = ThresholdTypeFixed - } if v := settings[SettingKeyBalanceLowNotifyThreshold]; v != "" { if f, err := strconv.ParseFloat(v, 64); err == nil { threshold = f diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 3de0e343..2704e0d0 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -251,13 +251,8 @@ const ( SettingKeyEnableCCHSigning = "enable_cch_signing" // Balance Low Notification - SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关 - SettingKeyBalanceLowNotifyThresholdType = "balance_low_notify_threshold_type" // "fixed" | "percentage" - SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD 或百分比) - - // Threshold type constants - ThresholdTypeFixed = "fixed" - ThresholdTypePercentage = "percentage" + SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关 + SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD) // Account Quota Notification SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 9b307426..f0cf750a 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -608,11 +608,6 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet // Balance low notification updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled) - thresholdType := settings.BalanceLowNotifyThresholdType - if thresholdType != ThresholdTypeFixed && thresholdType != ThresholdTypePercentage { - thresholdType = ThresholdTypeFixed - } - updates[SettingKeyBalanceLowNotifyThresholdType] = thresholdType updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64) accountQuotaNotifyEmailsJSON, err := json.Marshal(settings.AccountQuotaNotifyEmails) if err != nil { @@ -1252,10 +1247,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin // Balance low notification result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true" - result.BalanceLowNotifyThresholdType = settings[SettingKeyBalanceLowNotifyThresholdType] - if result.BalanceLowNotifyThresholdType == "" { - result.BalanceLowNotifyThresholdType = ThresholdTypeFixed - } if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 { result.BalanceLowNotifyThreshold = v } diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index b28d2247..debc2b19 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -108,9 +108,8 @@ type SystemSettings struct { EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) // Balance low notification - BalanceLowNotifyEnabled bool - BalanceLowNotifyThresholdType string // "fixed" (default) | "percentage" - BalanceLowNotifyThreshold float64 + BalanceLowNotifyEnabled bool + BalanceLowNotifyThreshold float64 // Account quota notification AccountQuotaNotifyEmails []string diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 26021a9b..e6b9a210 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -62,12 +62,11 @@ type UserRepository interface { // UpdateProfileRequest 更新用户资料请求 type UpdateProfileRequest struct { - Email *string `json:"email"` - Username *string `json:"username"` - Concurrency *int `json:"concurrency"` - BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` - BalanceNotifyThresholdType *string `json:"balance_notify_threshold_type"` - BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` + Email *string `json:"email"` + Username *string `json:"username"` + Concurrency *int `json:"concurrency"` + BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` + BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } // ChangePasswordRequest 修改密码请求 @@ -144,11 +143,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat if req.BalanceNotifyEnabled != nil { user.BalanceNotifyEnabled = *req.BalanceNotifyEnabled } - if req.BalanceNotifyThresholdType != nil { - if *req.BalanceNotifyThresholdType == ThresholdTypeFixed || *req.BalanceNotifyThresholdType == ThresholdTypePercentage { - user.BalanceNotifyThresholdType = *req.BalanceNotifyThresholdType - } - } if req.BalanceNotifyThreshold != nil { if *req.BalanceNotifyThreshold <= 0 { user.BalanceNotifyThreshold = nil // clear to system default diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index ec290be5..31284289 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -137,7 +137,6 @@ export interface SystemSettings { // Balance & quota notification balance_low_notify_enabled: boolean - balance_low_notify_threshold_type: 'fixed' | 'percentage' balance_low_notify_threshold: number account_quota_notify_emails: string[] } @@ -241,7 +240,6 @@ export interface UpdateSettingsRequest { payment_cancel_rate_limit_window_mode?: string // Balance & quota notification balance_low_notify_enabled?: boolean - balance_low_notify_threshold_type?: 'fixed' | 'percentage' balance_low_notify_threshold?: number account_quota_notify_emails?: string[] } diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 880a81ee..8e10bf2a 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -4633,12 +4633,8 @@ export default { title: 'Balance Low Notification', description: 'Send email notification when user balance falls below threshold', enabled: 'Enable Balance Low Notification', - thresholdType: 'Threshold Type', - typeFixed: 'Fixed Amount', - typePercentage: 'Percentage of Recharged', threshold: 'Default Threshold', thresholdHint: 'Used when user has not set a custom value', - percentageHint: 'Notify when balance falls below this percentage of total recharged amount', thresholdPlaceholder: 'Enter amount', }, quotaNotify: { diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 41d94e06..1b82f419 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -4797,12 +4797,8 @@ export default { title: '余额不足提醒', description: '当用户余额低于阈值时发送邮件提醒', enabled: '启用余额不足提醒', - thresholdType: '阈值类型', - typeFixed: '固定金额', - typePercentage: '充值百分比', - threshold: '提醒阈值', + threshold: '默认提醒阈值', thresholdHint: '用户未自定义时使用此值', - percentageHint: '当余额低于累计充值额的此百分比时提醒', thresholdPlaceholder: '输入金额', }, quotaNotify: { diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 50b532fb..faab49fe 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -2675,43 +2675,13 @@ -
- -
- -
- - -
-
- -
- -
- - {{ form.balance_low_notify_threshold_type === 'percentage' ? '%' : '$' }} - - -
-

- {{ form.balance_low_notify_threshold_type === 'percentage' - ? t('admin.settings.balanceNotify.percentageHint') - : t('admin.settings.balanceNotify.thresholdHint') }} -

+
+ +
+ $ +
+

{{ t('admin.settings.balanceNotify.thresholdHint') }}

@@ -3026,7 +2996,6 @@ const form = reactive({ enable_cch_signing: false, // Balance & quota notification balance_low_notify_enabled: false, - balance_low_notify_threshold_type: 'fixed' as 'fixed' | 'percentage', balance_low_notify_threshold: 0, account_quota_notify_emails: [] as string[] }) @@ -3591,7 +3560,6 @@ async function saveSettings() { payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode, // Balance & quota notification balance_low_notify_enabled: form.balance_low_notify_enabled, - balance_low_notify_threshold_type: form.balance_low_notify_threshold_type, balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0, account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e: string) => e.trim() !== ''), } From 889b5b4f3bfd6f3f5422b65fea44419710a55c7f Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 12 Apr 2026 15:59:45 +0800 Subject: [PATCH 27/88] fix(websearch): improve settings UI and hide config when globally disabled - API Key show/copy buttons moved inside input field (inline icons) - Proxy selector and test button on same row to save vertical space - Test opens a dialog modal instead of inline display - Hide all websearch config in channels/accounts when global toggle is off --- backend/cmd/server/VERSION | 2 +- .../components/account/CreateAccountModal.vue | 10 +- .../components/account/EditAccountModal.vue | 10 +- frontend/src/views/admin/ChannelsView.vue | 86 ++++++++-- frontend/src/views/admin/SettingsView.vue | 155 ++++++++++-------- 5 files changed, 177 insertions(+), 86 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index e534f2aa..061bed0b 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.108.140 +0.1.110.11 diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index d5d8ff89..e0dbce61 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2325,9 +2325,9 @@ - +
@@ -2998,6 +2998,12 @@ const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OF const codexCLIOnlyEnabled = ref(false) const anthropicPassthroughEnabled = ref(false) const webSearchEmulationEnabled = ref(false) +const webSearchGlobalEnabled = ref(false) + +// Load web search global state once +adminAPI.settings.getWebSearchEmulationConfig().then(cfg => { + webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0 +}).catch(() => { webSearchGlobalEnabled.value = false }) const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 54dae5c2..086575e6 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1149,9 +1149,9 @@
- +
@@ -1975,6 +1975,12 @@ const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OF const codexCLIOnlyEnabled = ref(false) const anthropicPassthroughEnabled = ref(false) const webSearchEmulationEnabled = ref(false) +const webSearchGlobalEnabled = ref(false) + +// Load web search global state once +adminAPI.settings.getWebSearchEmulationConfig().then(cfg => { + webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0 +}).catch(() => { webSearchGlobalEnabled.value = false }) const editQuotaLimit = ref(null) const editQuotaDailyLimit = ref(null) const editQuotaWeeklyLimit = ref(null) diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index a49e1694..639ace4a 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -306,6 +306,21 @@
+ +
+
+
+ +

+ {{ t('admin.channels.form.webSearchEmulationHint') }} +

+
+ +
+
+
@@ -560,6 +575,7 @@ import { ref, reactive, computed, onMounted, onUnmounted } from 'vue' import { useI18n } from 'vue-i18n' import { useAppStore } from '@/stores/app' +import { extractApiErrorMessage } from '@/utils/apiError' import { adminAPI } from '@/api/admin' import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest, AccountStatsPricingRule } from '@/api/admin/channels' import type { PricingFormEntry } from '@/components/admin/channel/types' @@ -583,6 +599,18 @@ import { getPersistedPageSize } from '@/composables/usePersistedPageSize' const { t } = useI18n() const appStore = useAppStore() +// Web Search global enabled state (loaded once on mount) +const webSearchGlobalEnabled = ref(false) +async function loadWebSearchGlobalState() { + try { + const cfg = await adminAPI.settings.getWebSearchEmulationConfig() + webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0 + } catch (err: unknown) { + console.warn('Failed to load web search global state:', err) + webSearchGlobalEnabled.value = false + } +} + // ── Platform Section type ── interface PlatformSection { platform: GroupPlatform @@ -591,6 +619,7 @@ interface PlatformSection { group_ids: number[] model_mapping: Record model_pricing: PricingFormEntry[] + web_search_emulation: boolean } // ── Table columns ── @@ -709,7 +738,8 @@ function addPlatformSection(platform: GroupPlatform) { collapsed: false, group_ids: [], model_mapping: {}, - model_pricing: [] + model_pricing: [], + web_search_emulation: false, }) } @@ -901,10 +931,14 @@ function accountStatsRulesToAPI(): AccountStatsPricingRule[] { } // ── Form ↔ API conversion ── -function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record> } { +function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record>, features_config: Record } { const group_ids: number[] = [] const model_pricing: ChannelModelPricing[] = [] const model_mapping: Record> = {} + // Preserve existing features_config fields not managed by the form + const featuresConfig: Record = editingChannel.value?.features_config + ? { ...editingChannel.value.features_config } + : {} for (const section of form.platforms) { if (!section.enabled) continue @@ -933,7 +967,19 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[ } } - return { group_ids, model_pricing, model_mapping } + // Collect web_search_emulation (only anthropic platform supports it) + const wsEmulation: Record = {} + for (const section of form.platforms) { + if (!section.enabled) continue + if (section.web_search_emulation && section.platform === 'anthropic') { + wsEmulation[section.platform] = true + } + } + if (Object.keys(wsEmulation).length > 0) { + featuresConfig.web_search_emulation = wsEmulation + } + + return { group_ids, model_pricing, model_mapping, features_config: featuresConfig } } function apiToForm(channel: Channel): PlatformSection[] { @@ -977,13 +1023,19 @@ function apiToForm(channel: Channel): PlatformSection[] { intervals: apiIntervalsToForm(p.intervals || []) } as PricingFormEntry)) + // Read web_search_emulation from features_config + const fc = channel.features_config + const wsEmulation = fc?.web_search_emulation as Record | undefined + const webSearchEnabled = wsEmulation?.[platform] === true + sections.push({ platform, enabled: true, collapsed: false, group_ids: groupIds, model_mapping: { ...mapping }, - model_pricing: pricing + model_pricing: pricing, + web_search_emulation: webSearchEnabled, }) } @@ -1008,10 +1060,10 @@ async function loadChannels() { if (ctrl.signal.aborted || abortController !== ctrl) return channels.value = response.items || [] pagination.total = response.total - } catch (error: any) { - if (error?.name === 'AbortError' || error?.code === 'ERR_CANCELED') return - appStore.showError(t('admin.channels.loadError', 'Failed to load channels')) - console.error('Error loading channels:', error) + } catch (error: unknown) { + const e = error as { name?: string; code?: string } + if (e?.name === 'AbortError' || e?.code === 'ERR_CANCELED') return + appStore.showError(extractApiErrorMessage(error, t('admin.channels.loadError', 'Failed to load channels'))) } finally { if (abortController === ctrl) { loading.value = false @@ -1210,7 +1262,7 @@ async function handleSubmit() { } } - const { group_ids, model_pricing, model_mapping } = formToAPI() + const { group_ids, model_pricing, model_mapping, features_config } = formToAPI() submitting.value = true try { @@ -1224,6 +1276,7 @@ async function handleSubmit() { model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {}, billing_model_source: form.billing_model_source, restrict_models: form.restrict_models, + features_config, apply_pricing_to_account_stats: form.apply_pricing_to_account_stats, account_stats_pricing_rules: accountStatsRulesToAPI() } @@ -1238,6 +1291,7 @@ async function handleSubmit() { model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {}, billing_model_source: form.billing_model_source, restrict_models: form.restrict_models, + features_config, apply_pricing_to_account_stats: form.apply_pricing_to_account_stats, account_stats_pricing_rules: accountStatsRulesToAPI() } @@ -1246,12 +1300,10 @@ async function handleSubmit() { } closeDialog() loadChannels() - } catch (error: any) { - const msg = error.response?.data?.detail || (editingChannel.value + } catch (error: unknown) { + appStore.showError(extractApiErrorMessage(error, editingChannel.value ? t('admin.channels.updateError', 'Failed to update channel') - : t('admin.channels.createError', 'Failed to create channel')) - appStore.showError(msg) - console.error('Error saving channel:', error) + : t('admin.channels.createError', 'Failed to create channel'))) } finally { submitting.value = false } @@ -1289,9 +1341,8 @@ async function confirmDelete() { showDeleteDialog.value = false deletingChannel.value = null loadChannels() - } catch (error: any) { - appStore.showError(error.response?.data?.detail || t('admin.channels.deleteError', 'Failed to delete channel')) - console.error('Error deleting channel:', error) + } catch (error: unknown) { + appStore.showError(extractApiErrorMessage(error, t('admin.channels.deleteError', 'Failed to delete channel'))) } } @@ -1299,6 +1350,7 @@ async function confirmDelete() { onMounted(() => { loadChannels() loadGroups() + loadWebSearchGlobalState() }) onUnmounted(() => { diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index faab49fe..b5181ccc 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -1789,40 +1789,42 @@
- +
-
+
- - +
+ + +
@@ -1858,44 +1860,19 @@ {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit }}
- -
- - -
- - -
-
- - -
- -
-

- {{ t('admin.settings.webSearchEmulation.testResultProvider') }}: {{ wsTestResult.provider }} -

-
- {{ t('admin.settings.webSearchEmulation.testNoResults') }} -
-
- {{ r.title }} -

{{ r.snippet && r.snippet.length > 120 ? r.snippet.slice(0, 120) + '...' : r.snippet }}

-
+ +
+
+ +
+
@@ -1903,6 +1880,50 @@
+ +
+
+

+ {{ t('admin.settings.webSearchEmulation.testResultTitle') }} +

+
+ + +
+ +
+

+ {{ t('admin.settings.webSearchEmulation.testResultProvider') }}: {{ wsTestResult.provider }} +

+
+ {{ t('admin.settings.webSearchEmulation.testNoResults') }} +
+
+ {{ r.title }} +

{{ r.snippet }}

+
+
+
+ +
+
+
+
@@ -3016,6 +3037,12 @@ const apiKeyVisible = reactive>({}) const wsTestQuery = ref('') const wsTestLoading = ref(false) const wsTestResult = ref(null) +const wsTestDialogOpen = ref(false) + +function openTestDialog() { + wsTestResult.value = null + wsTestDialogOpen.value = true +} function toggleProviderExpand(idx: number) { expandedProviders[idx] = !expandedProviders[idx] From eba289a7ff82430a1570d84a97c4bdf2d10e9775 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 12 Apr 2026 17:49:58 +0800 Subject: [PATCH 28/88] feat(notify): add global toggles, percentage threshold, and visibility control - Add global toggle for account quota notification in admin settings - Add percentage-based threshold type for per-account quota alerts - Hide balance notify card on user profile when global toggle is off - Expose balance_low_notify_enabled and account_quota_notify_enabled in PublicSettings - Add threshold type (fixed/percentage) to QuotaNotifyToggle with $ / % switcher --- backend/internal/service/account.go | 20 ++++++++ .../service/balance_notify_service.go | 50 ++++++++++++++----- backend/internal/service/domain_constants.go | 3 +- backend/internal/service/setting_service.go | 12 ++++- backend/internal/service/settings_view.go | 6 ++- frontend/src/api/admin/settings.ts | 2 + .../components/account/EditAccountModal.vue | 24 +++++++++ .../src/components/account/QuotaLimitCard.vue | 15 ++++++ .../components/account/QuotaNotifyToggle.vue | 31 ++++++++++-- frontend/src/i18n/locales/en.ts | 3 +- frontend/src/i18n/locales/zh.ts | 3 +- frontend/src/stores/app.ts | 4 +- frontend/src/types/index.ts | 2 + frontend/src/views/admin/SettingsView.vue | 10 +++- frontend/src/views/user/ProfileView.vue | 5 +- 15 files changed, 164 insertions(+), 26 deletions(-) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 0b225dac..6e5a768f 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -1432,6 +1432,14 @@ func (a *Account) getExtraString(key string) string { return "" } +// getExtraStringDefault 从 Extra 中读取指定 key 的字符串值,不存在时返回 defaultVal +func (a *Account) getExtraStringDefault(key, defaultVal string) string { + if v := a.getExtraString(key); v != "" { + return v + } + return defaultVal +} + // getExtraInt 从 Extra 中读取指定 key 的 int 值 func (a *Account) getExtraInt(key string) int { if a.Extra == nil { @@ -1498,6 +1506,10 @@ func (a *Account) GetQuotaNotifyDailyThreshold() float64 { return a.getExtraFloat64("quota_notify_daily_threshold") } +func (a *Account) GetQuotaNotifyDailyThresholdType() string { + return a.getExtraStringDefault("quota_notify_daily_threshold_type", "fixed") +} + func (a *Account) GetQuotaNotifyWeeklyEnabled() bool { return a.getExtraBool("quota_notify_weekly_enabled") } @@ -1506,6 +1518,10 @@ func (a *Account) GetQuotaNotifyWeeklyThreshold() float64 { return a.getExtraFloat64("quota_notify_weekly_threshold") } +func (a *Account) GetQuotaNotifyWeeklyThresholdType() string { + return a.getExtraStringDefault("quota_notify_weekly_threshold_type", "fixed") +} + func (a *Account) GetQuotaNotifyTotalEnabled() bool { return a.getExtraBool("quota_notify_total_enabled") } @@ -1514,6 +1530,10 @@ func (a *Account) GetQuotaNotifyTotalThreshold() float64 { return a.getExtraFloat64("quota_notify_total_threshold") } +func (a *Account) GetQuotaNotifyTotalThresholdType() string { + return a.getExtraStringDefault("quota_notify_total_threshold_type", "fixed") +} + // nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点 func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time { t := after.In(tz) diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 0d7e4c09..65cec594 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -82,19 +82,29 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u // quotaDim describes one quota dimension for notification checking. type quotaDim struct { - name string - enabled bool - threshold float64 - oldUsed float64 - limit float64 + name string + enabled bool + threshold float64 + thresholdType string // "fixed" (default) or "percentage" + oldUsed float64 + limit float64 +} + +// resolvedThreshold returns the effective threshold value. +// For percentage type, it computes threshold = limit * percentage / 100. +func (d quotaDim) resolvedThreshold() float64 { + if d.thresholdType == "percentage" && d.limit > 0 { + return d.limit * d.threshold / 100 + } + return d.threshold } // buildQuotaDims returns the three quota dimensions for notification checking. func buildQuotaDims(account *Account) []quotaDim { return []quotaDim{ - {quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaDailyUsed(), account.GetQuotaDailyLimit()}, - {quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaWeeklyUsed(), account.GetQuotaWeeklyLimit()}, - {quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaUsed(), account.GetQuotaLimit()}, + {quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaNotifyDailyThresholdType(), account.GetQuotaDailyUsed(), account.GetQuotaDailyLimit()}, + {quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaNotifyWeeklyThresholdType(), account.GetQuotaWeeklyUsed(), account.GetQuotaWeeklyLimit()}, + {quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaNotifyTotalThresholdType(), account.GetQuotaUsed(), account.GetQuotaLimit()}, } } @@ -104,6 +114,9 @@ func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Conte if account == nil || s.emailService == nil || s.settingRepo == nil || cost <= 0 { return } + if !s.isAccountQuotaNotifyEnabled(ctx) { + return + } adminEmails := s.getAccountQuotaNotifyEmails(ctx) if len(adminEmails) == 0 { return @@ -114,22 +127,26 @@ func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Conte if !dim.enabled || dim.threshold <= 0 { continue } + effectiveThreshold := dim.resolvedThreshold() + if effectiveThreshold <= 0 { + continue + } newUsed := dim.oldUsed + cost - if dim.oldUsed < dim.threshold && newUsed >= dim.threshold { - s.asyncSendQuotaAlert(adminEmails, account.Name, dim, newUsed, siteName) + if dim.oldUsed < effectiveThreshold && newUsed >= effectiveThreshold { + s.asyncSendQuotaAlert(adminEmails, account.Name, dim, newUsed, effectiveThreshold, siteName) } } } // asyncSendQuotaAlert sends quota alert email in a goroutine with panic recovery. -func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, accountName string, dim quotaDim, newUsed float64, siteName string) { +func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, accountName string, dim quotaDim, newUsed, effectiveThreshold float64, siteName string) { go func() { defer func() { if r := recover(); r != nil { slog.Error("panic in quota notification", "recover", r) } }() - s.sendQuotaAlertEmails(adminEmails, accountName, dim.name, newUsed, dim.limit, dim.threshold, siteName) + s.sendQuotaAlertEmails(adminEmails, accountName, dim.name, newUsed, dim.limit, effectiveThreshold, siteName) }() } @@ -149,6 +166,15 @@ func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enab return } +// isAccountQuotaNotifyEnabled checks the global account quota notification toggle. +func (s *BalanceNotifyService) isAccountQuotaNotifyEnabled(ctx context.Context) bool { + val, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEnabled) + if err != nil { + return false + } + return val == "true" +} + // getAccountQuotaNotifyEmails reads admin notification emails from settings. func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) []string { raw, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEmails) diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 2704e0d0..f07ddfd4 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -255,7 +255,8 @@ const ( SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD) // Account Quota Notification - SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组) + SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关 + SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组) ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f0cf750a..abcae9c1 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -182,6 +182,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingPaymentEnabled, SettingKeyOIDCConnectEnabled, SettingKeyOIDCConnectProviderName, + SettingKeyBalanceLowNotifyEnabled, + SettingKeyAccountQuotaNotifyEnabled, } settings, err := s.settingRepo.GetMultiple(ctx, keys) @@ -249,6 +251,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings PaymentEnabled: settings[SettingPaymentEnabled] == "true", OIDCOAuthEnabled: oidcEnabled, OIDCOAuthProviderName: oidcProviderName, + BalanceLowNotifyEnabled: settings[SettingKeyBalanceLowNotifyEnabled] == "true", + AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true", }, nil } @@ -302,6 +306,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` Version string `json:"version,omitempty"` + BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` + AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` }{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, @@ -332,6 +338,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthProviderName: settings.OIDCOAuthProviderName, Version: s.version, + BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, + AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled, }, nil } @@ -609,6 +617,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet // Balance low notification updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled) updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64) + updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled) accountQuotaNotifyEmailsJSON, err := json.Marshal(settings.AccountQuotaNotifyEmails) if err != nil { return fmt.Errorf("marshal account quota notify emails: %w", err) @@ -1251,7 +1260,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.BalanceLowNotifyThreshold = v } - // Account quota notification emails + // Account quota notification + result.AccountQuotaNotifyEnabled = settings[SettingKeyAccountQuotaNotifyEnabled] == "true" if raw := strings.TrimSpace(settings[SettingKeyAccountQuotaNotifyEmails]); raw != "" { var emails []string if err := json.Unmarshal([]byte(raw), &emails); err == nil { diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index debc2b19..b79b930a 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -112,7 +112,8 @@ type SystemSettings struct { BalanceLowNotifyThreshold float64 // Account quota notification - AccountQuotaNotifyEmails []string + AccountQuotaNotifyEnabled bool + AccountQuotaNotifyEmails []string } type DefaultSubscriptionSetting struct { @@ -152,6 +153,9 @@ type PublicSettings struct { OIDCOAuthEnabled bool OIDCOAuthProviderName string Version string + + BalanceLowNotifyEnabled bool + AccountQuotaNotifyEnabled bool } // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 31284289..5c5de2d1 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -138,6 +138,7 @@ export interface SystemSettings { // Balance & quota notification balance_low_notify_enabled: boolean balance_low_notify_threshold: number + account_quota_notify_enabled: boolean account_quota_notify_emails: string[] } @@ -241,6 +242,7 @@ export interface UpdateSettingsRequest { // Balance & quota notification balance_low_notify_enabled?: boolean balance_low_notify_threshold?: number + account_quota_notify_enabled?: boolean account_quota_notify_emails?: string[] } diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 086575e6..abb9569e 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1188,10 +1188,13 @@ :resetTimezone="editResetTimezone" :quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled" :quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold" + :quotaNotifyDailyThresholdType="editQuotaNotifyDailyThresholdType" :quotaNotifyWeeklyEnabled="editQuotaNotifyWeeklyEnabled" :quotaNotifyWeeklyThreshold="editQuotaNotifyWeeklyThreshold" + :quotaNotifyWeeklyThresholdType="editQuotaNotifyWeeklyThresholdType" :quotaNotifyTotalEnabled="editQuotaNotifyTotalEnabled" :quotaNotifyTotalThreshold="editQuotaNotifyTotalThreshold" + :quotaNotifyTotalThresholdType="editQuotaNotifyTotalThresholdType" @update:totalLimit="editQuotaLimit = $event" @update:dailyLimit="editQuotaDailyLimit = $event" @update:weeklyLimit="editQuotaWeeklyLimit = $event" @@ -1203,10 +1206,13 @@ @update:resetTimezone="editResetTimezone = $event" @update:quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled = $event" @update:quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold = $event" + @update:quotaNotifyDailyThresholdType="editQuotaNotifyDailyThresholdType = $event" @update:quotaNotifyWeeklyEnabled="editQuotaNotifyWeeklyEnabled = $event" @update:quotaNotifyWeeklyThreshold="editQuotaNotifyWeeklyThreshold = $event" + @update:quotaNotifyWeeklyThresholdType="editQuotaNotifyWeeklyThresholdType = $event" @update:quotaNotifyTotalEnabled="editQuotaNotifyTotalEnabled = $event" @update:quotaNotifyTotalThreshold="editQuotaNotifyTotalThreshold = $event" + @update:quotaNotifyTotalThresholdType="editQuotaNotifyTotalThresholdType = $event" /> @@ -1232,10 +1238,13 @@ :resetTimezone="editResetTimezone" :quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled" :quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold" + :quotaNotifyDailyThresholdType="editQuotaNotifyDailyThresholdType" :quotaNotifyWeeklyEnabled="editQuotaNotifyWeeklyEnabled" :quotaNotifyWeeklyThreshold="editQuotaNotifyWeeklyThreshold" + :quotaNotifyWeeklyThresholdType="editQuotaNotifyWeeklyThresholdType" :quotaNotifyTotalEnabled="editQuotaNotifyTotalEnabled" :quotaNotifyTotalThreshold="editQuotaNotifyTotalThreshold" + :quotaNotifyTotalThresholdType="editQuotaNotifyTotalThresholdType" @update:totalLimit="editQuotaLimit = $event" @update:dailyLimit="editQuotaDailyLimit = $event" @update:weeklyLimit="editQuotaWeeklyLimit = $event" @@ -1247,10 +1256,13 @@ @update:resetTimezone="editResetTimezone = $event" @update:quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled = $event" @update:quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold = $event" + @update:quotaNotifyDailyThresholdType="editQuotaNotifyDailyThresholdType = $event" @update:quotaNotifyWeeklyEnabled="editQuotaNotifyWeeklyEnabled = $event" @update:quotaNotifyWeeklyThreshold="editQuotaNotifyWeeklyThreshold = $event" + @update:quotaNotifyWeeklyThresholdType="editQuotaNotifyWeeklyThresholdType = $event" @update:quotaNotifyTotalEnabled="editQuotaNotifyTotalEnabled = $event" @update:quotaNotifyTotalThreshold="editQuotaNotifyTotalThreshold = $event" + @update:quotaNotifyTotalThresholdType="editQuotaNotifyTotalThresholdType = $event" /> @@ -1992,10 +2004,13 @@ const editWeeklyResetHour = ref(null) const editResetTimezone = ref(null) const editQuotaNotifyDailyEnabled = ref(null) const editQuotaNotifyDailyThreshold = ref(null) +const editQuotaNotifyDailyThresholdType = ref(null) const editQuotaNotifyWeeklyEnabled = ref(null) const editQuotaNotifyWeeklyThreshold = ref(null) +const editQuotaNotifyWeeklyThresholdType = ref(null) const editQuotaNotifyTotalEnabled = ref(null) const editQuotaNotifyTotalThreshold = ref(null) +const editQuotaNotifyTotalThresholdType = ref(null) const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复 @@ -2198,10 +2213,13 @@ const syncFormFromAccount = (newAccount: Account | null) => { // Load quota notify config editQuotaNotifyDailyEnabled.value = (extra?.quota_notify_daily_enabled as boolean) ?? null editQuotaNotifyDailyThreshold.value = (extra?.quota_notify_daily_threshold as number) ?? null + editQuotaNotifyDailyThresholdType.value = (extra?.quota_notify_daily_threshold_type as string) ?? null editQuotaNotifyWeeklyEnabled.value = (extra?.quota_notify_weekly_enabled as boolean) ?? null editQuotaNotifyWeeklyThreshold.value = (extra?.quota_notify_weekly_threshold as number) ?? null + editQuotaNotifyWeeklyThresholdType.value = (extra?.quota_notify_weekly_threshold_type as string) ?? null editQuotaNotifyTotalEnabled.value = (extra?.quota_notify_total_enabled as boolean) ?? null editQuotaNotifyTotalThreshold.value = (extra?.quota_notify_total_threshold as number) ?? null + editQuotaNotifyTotalThresholdType.value = (extra?.quota_notify_total_threshold_type as string) ?? null } else { editQuotaLimit.value = null editQuotaDailyLimit.value = null @@ -3262,9 +3280,11 @@ const handleSubmit = async () => { } else { delete newExtra.quota_notify_daily_threshold } + newExtra.quota_notify_daily_threshold_type = editQuotaNotifyDailyThresholdType.value || 'fixed' } else { delete newExtra.quota_notify_daily_enabled delete newExtra.quota_notify_daily_threshold + delete newExtra.quota_notify_daily_threshold_type } if (editQuotaNotifyWeeklyEnabled.value) { newExtra.quota_notify_weekly_enabled = true @@ -3273,9 +3293,11 @@ const handleSubmit = async () => { } else { delete newExtra.quota_notify_weekly_threshold } + newExtra.quota_notify_weekly_threshold_type = editQuotaNotifyWeeklyThresholdType.value || 'fixed' } else { delete newExtra.quota_notify_weekly_enabled delete newExtra.quota_notify_weekly_threshold + delete newExtra.quota_notify_weekly_threshold_type } if (editQuotaNotifyTotalEnabled.value) { newExtra.quota_notify_total_enabled = true @@ -3284,9 +3306,11 @@ const handleSubmit = async () => { } else { delete newExtra.quota_notify_total_threshold } + newExtra.quota_notify_total_threshold_type = editQuotaNotifyTotalThresholdType.value || 'fixed' } else { delete newExtra.quota_notify_total_enabled delete newExtra.quota_notify_total_threshold + delete newExtra.quota_notify_total_threshold_type } updatePayload.extra = newExtra } diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index 7c3afd23..64bdb08a 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -17,17 +17,23 @@ const props = withDefaults(defineProps<{ resetTimezone: string | null quotaNotifyDailyEnabled?: boolean | null quotaNotifyDailyThreshold?: number | null + quotaNotifyDailyThresholdType?: string | null quotaNotifyWeeklyEnabled?: boolean | null quotaNotifyWeeklyThreshold?: number | null + quotaNotifyWeeklyThresholdType?: string | null quotaNotifyTotalEnabled?: boolean | null quotaNotifyTotalThreshold?: number | null + quotaNotifyTotalThresholdType?: string | null }>(), { quotaNotifyDailyEnabled: null, quotaNotifyDailyThreshold: null, + quotaNotifyDailyThresholdType: null, quotaNotifyWeeklyEnabled: null, quotaNotifyWeeklyThreshold: null, + quotaNotifyWeeklyThresholdType: null, quotaNotifyTotalEnabled: null, quotaNotifyTotalThreshold: null, + quotaNotifyTotalThresholdType: null, }) const emit = defineEmits<{ @@ -42,10 +48,13 @@ const emit = defineEmits<{ 'update:resetTimezone': [value: string | null] 'update:quotaNotifyDailyEnabled': [value: boolean | null] 'update:quotaNotifyDailyThreshold': [value: number | null] + 'update:quotaNotifyDailyThresholdType': [value: string | null] 'update:quotaNotifyWeeklyEnabled': [value: boolean | null] 'update:quotaNotifyWeeklyThreshold': [value: number | null] + 'update:quotaNotifyWeeklyThresholdType': [value: string | null] 'update:quotaNotifyTotalEnabled': [value: boolean | null] 'update:quotaNotifyTotalThreshold': [value: number | null] + 'update:quotaNotifyTotalThresholdType': [value: string | null] }>() const enabled = computed(() => @@ -228,8 +237,10 @@ const onWeeklyModeChange = (e: Event) => { v-if="dailyLimit && dailyLimit > 0" :enabled="props.quotaNotifyDailyEnabled" :threshold="props.quotaNotifyDailyThreshold" + :threshold-type="props.quotaNotifyDailyThresholdType" @update:enabled="emit('update:quotaNotifyDailyEnabled', $event)" @update:threshold="emit('update:quotaNotifyDailyThreshold', $event)" + @update:threshold-type="emit('update:quotaNotifyDailyThresholdType', $event)" /> @@ -292,8 +303,10 @@ const onWeeklyModeChange = (e: Event) => { v-if="weeklyLimit && weeklyLimit > 0" :enabled="props.quotaNotifyWeeklyEnabled" :threshold="props.quotaNotifyWeeklyThreshold" + :threshold-type="props.quotaNotifyWeeklyThresholdType" @update:enabled="emit('update:quotaNotifyWeeklyEnabled', $event)" @update:threshold="emit('update:quotaNotifyWeeklyThreshold', $event)" + @update:threshold-type="emit('update:quotaNotifyWeeklyThresholdType', $event)" /> @@ -330,8 +343,10 @@ const onWeeklyModeChange = (e: Event) => { v-if="totalLimit && totalLimit > 0" :enabled="props.quotaNotifyTotalEnabled" :threshold="props.quotaNotifyTotalThreshold" + :threshold-type="props.quotaNotifyTotalThresholdType" @update:enabled="emit('update:quotaNotifyTotalEnabled', $event)" @update:threshold="emit('update:quotaNotifyTotalThreshold', $event)" + @update:threshold-type="emit('update:quotaNotifyTotalThresholdType', $event)" /> diff --git a/frontend/src/components/account/QuotaNotifyToggle.vue b/frontend/src/components/account/QuotaNotifyToggle.vue index 4634f5b1..b1c22fe2 100644 --- a/frontend/src/components/account/QuotaNotifyToggle.vue +++ b/frontend/src/components/account/QuotaNotifyToggle.vue @@ -6,12 +6,18 @@ const { t } = useI18n() defineProps<{ enabled: boolean | null threshold: number | null + thresholdType: string | null // "fixed" (default) or "percentage" }>() const emit = defineEmits<{ 'update:enabled': [value: boolean | null] 'update:threshold': [value: number | null] + 'update:thresholdType': [value: string | null] }>() + +function toggleType(current: string | null) { + emit('update:thresholdType', current === 'percentage' ? 'fixed' : 'percentage') +} \ No newline at end of file From 422807514c9f147c12f09b86decb9cbbc9312627 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 12 Apr 2026 20:40:31 +0800 Subject: [PATCH 35/88] fix(notify): add duplicate email check message and improve extra email UX --- .../user/profile/ProfileBalanceNotifyCard.vue | 12 +++++++----- frontend/src/i18n/locales/en.ts | 1 + frontend/src/i18n/locales/zh.ts | 1 + 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue index cfe1b332..797616bb 100644 --- a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue +++ b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue @@ -49,14 +49,16 @@ - +
{{ email }} - +
+ +
@@ -185,7 +187,7 @@ function addPendingEmail() { const email = newEmail.value.trim() if (!email) return if (email === props.userEmail || extraEmails.value.includes(email) || pendingEmails.value.some(p => p.email === email)) { - appStore.showError(t('common.error')) + appStore.showError(t('profile.balanceNotify.emailDuplicate')) return } pendingEmails.value.push({ email, codeSent: false, code: '', sending: false, verifying: false, countdown: 0, timer: null }) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 1b4cb9ea..ff77837e 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -929,6 +929,7 @@ export default { verifySuccess: 'Email added successfully', removeEmail: 'Remove', removeSuccess: 'Email removed', + emailDuplicate: 'This email already exists', } }, diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index ca091dd5..9eabb465 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -933,6 +933,7 @@ export default { verifySuccess: '邮箱添加成功', removeEmail: '移除', removeSuccess: '邮箱已移除', + emailDuplicate: '该邮箱已存在', } }, From 61aa197b0bb3905e5e5b582eeb25c25995670f96 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 12 Apr 2026 20:45:58 +0800 Subject: [PATCH 36/88] fix(notify): add explicit save button for balance threshold Replace blur-based auto-save with an explicit Save button so users know when their threshold is persisted. Shows success toast on save. --- .../user/profile/ProfileBalanceNotifyCard.vue | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue index 797616bb..589cc9e8 100644 --- a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue +++ b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue @@ -19,7 +19,7 @@ @@ -124,12 +130,15 @@ From 11c4606874d1ade3219ce75bf405463118998172 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 02:28:31 +0800 Subject: [PATCH 40/88] fix(channel): use upstream model for account stats pricing and remove channel pricing fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - resolveAccountStatsCost now uses the final upstream model (after account-level mapping) to match custom pricing rules, fixing the issue where requested model (e.g. claude-sonnet-4-5) didn't match rules configured for upstream model (e.g. claude-opus-4-6) - Remove tryChannelPricing fallback — only custom rules are applied, unmatched requests use default formula (total_cost × rate) - Remove unused billingService and serviceTier parameters - Update description: "启用后将支持自定义账号统计的模型价格" --- .../internal/service/account_stats_pricing.go | 38 ++++--------------- backend/internal/service/gateway_service.go | 13 ++++--- .../service/openai_gateway_service.go | 12 ++++-- frontend/src/i18n/locales/en.ts | 2 +- frontend/src/i18n/locales/zh.ts | 2 +- 5 files changed, 25 insertions(+), 42 deletions(-) diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index 86f98a12..4a896d9f 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -8,23 +8,18 @@ import ( // resolveAccountStatsCost 计算账号统计定价费用。 // 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。 -// -// 匹配优先级(先命中为准): -// 1. 自定义规则(AccountStatsPricingRules,按数组顺序遍历) -// 2. 渠道已有的模型定价(ApplyPricingToAccountStats 开启时) -// 3. nil → 走默认公式 +// 仅匹配自定义规则(AccountStatsPricingRules),按数组顺序先命中为准。 +// upstreamModel 是最终发往上游的模型 ID,用于匹配自定义规则中的模型定价。 func resolveAccountStatsCost( ctx context.Context, channelService *ChannelService, - billingService *BillingService, accountID int64, groupID int64, - billingModel string, + upstreamModel string, tokens UsageTokens, requestCount int, - serviceTier string, ) *float64 { - if channelService == nil || billingService == nil { + if channelService == nil || upstreamModel == "" { return nil } channel, err := channelService.GetChannelForGroup(ctx, groupID) @@ -33,22 +28,15 @@ func resolveAccountStatsCost( } platform := channelService.GetGroupPlatform(ctx, groupID) - modelLower := strings.ToLower(billingModel) - - // 优先级 1:自定义规则 - if cost := tryCustomRules(channel, accountID, groupID, platform, modelLower, tokens, requestCount); cost != nil { - return cost - } - - // 优先级 2:渠道已有模型定价 - return tryChannelPricing(ctx, channelService, groupID, billingModel, tokens, requestCount) + return tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount) } // tryCustomRules 遍历自定义规则,按数组顺序先命中为准。 func tryCustomRules( channel *Channel, accountID, groupID int64, - platform, modelLower string, tokens UsageTokens, requestCount int, + platform, model string, tokens UsageTokens, requestCount int, ) *float64 { + modelLower := strings.ToLower(model) for _, rule := range channel.AccountStatsPricingRules { if !matchAccountStatsRule(&rule, accountID, groupID) { continue @@ -62,18 +50,6 @@ func tryCustomRules( return nil } -// tryChannelPricing 使用渠道已有的模型定价计算账号统计费用。 -func tryChannelPricing( - ctx context.Context, channelService *ChannelService, - groupID int64, billingModel string, tokens UsageTokens, requestCount int, -) *float64 { - pricing := channelService.GetChannelModelPricing(ctx, groupID, billingModel) - if pricing == nil { - return nil - } - return calculateStatsCost(pricing, tokens, requestCount) -} - // matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。 // 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。 // 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。 diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index d68a7771..70dd9b52 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -7581,11 +7581,15 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription, requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts) - // 计算账号统计定价费用 + // 计算账号统计定价费用(使用最终上游模型匹配自定义规则) if apiKey.GroupID != nil { + upstreamModel := result.UpstreamModel + if upstreamModel == "" { + upstreamModel = result.Model + } usageLog.AccountStatsCost = resolveAccountStatsCost( - ctx, s.channelService, s.billingService, - account.ID, *apiKey.GroupID, billingModel, + ctx, s.channelService, + account.ID, *apiKey.GroupID, upstreamModel, UsageTokens{ InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -7593,8 +7597,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage CacheReadTokens: result.Usage.CacheReadInputTokens, ImageOutputTokens: result.Usage.ImageOutputTokens, }, - 1, // requestCount - "", // serviceTier: Anthropic 平台不使用 service tier + 1, // requestCount ) } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 70abd4ce..98258cd0 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4573,12 +4573,16 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec usageLog.SubscriptionID = &subscription.ID } - // 计算账号统计定价费用 + // 计算账号统计定价费用(使用最终上游模型匹配自定义规则) if apiKey.GroupID != nil { + statsModel := result.UpstreamModel + if statsModel == "" { + statsModel = result.Model + } usageLog.AccountStatsCost = resolveAccountStatsCost( - ctx, s.channelService, s.billingService, - account.ID, *apiKey.GroupID, billingModel, - tokens, 1, serviceTier, + ctx, s.channelService, + account.ID, *apiKey.GroupID, statsModel, + tokens, 1, ) } diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 07ae0c3d..f45a02f6 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1877,7 +1877,7 @@ export default { pricingEntry: 'Pricing Entry', noModels: 'No models added', applyPricingToAccountStats: 'Apply Pricing to Account Stats', - applyPricingToAccountStatsDesc: 'When enabled, account statistics cost will use channel model pricing. Account rate multiplier still applies.', + applyPricingToAccountStatsDesc: 'When enabled, custom account stats model pricing rules will be applied.', accountStatsPricingRules: 'Custom Account Stats Pricing Rules', addRule: 'Add Rule', noRulesConfigured: 'No custom rules configured. Channel model pricing above will be used.', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 64d42c12..61d43a37 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1956,7 +1956,7 @@ export default { pricingEntry: '定价配置', noModels: '未添加模型', applyPricingToAccountStats: '应用模型定价到账号统计', - applyPricingToAccountStatsDesc: '启用后,账号统计费用将使用渠道模型定价计算。账号自身的统计倍率仍然生效。', + applyPricingToAccountStatsDesc: '启用后将支持自定义账号统计的模型价格', accountStatsPricingRules: '自定义账号统计定价规则', addRule: '添加规则', noRulesConfigured: '未配置自定义规则,将使用上方的模型定价。', From 1262654d9785673cfda3071a714f3a61734933e7 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 11:37:08 +0800 Subject: [PATCH 41/88] feat: WebSearch tri-state, account stats pricing fix, quota cache fix, usage tooltip MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit WebSearch tri-state switch: - Account-level web_search_emulation changed from bool to tri-state string: "default" (follow channel) / "enabled" / "disabled" - shouldEmulateWebSearch checks channel config when account is "default" - SQL migration converts old bool values - Frontend select replaces toggle in Edit/CreateAccountModal Account stats pricing: - resolveAccountStatsCost uses upstream model (post-mapping) for matching - Priority: custom rules → model pricing file (when toggle on) → default - Custom rules always configurable, independent of toggle - Account ID field changed to searchable selector filtered by platform - Description updated to reflect new behavior Quota notification cache fix: - CheckAccountQuotaAfterIncrement fetches real-time account from DB - Reconstructs pre-increment usage for accurate threshold crossing detection - New AccountQuotaReader interface (minimal: GetByID only) Usage tooltip: - Per-request/image billing shows per-request price instead of $0 token price - Token billing continues to show input/output price per million tokens --- backend/cmd/server/wire_gen.go | 2 +- backend/internal/handler/gateway_handler.go | 3 + backend/internal/service/account.go | 29 +++- .../internal/service/account_stats_pricing.go | 41 ++++- .../service/balance_notify_service.go | 43 +++++- backend/internal/service/gateway_request.go | 3 + backend/internal/service/gateway_service.go | 4 +- .../service/gateway_websearch_emulation.go | 24 ++- .../service/openai_gateway_service.go | 2 +- backend/internal/service/wire.go | 4 +- ...igrate_websearch_emulation_to_tristate.sql | 11 ++ .../components/account/CreateAccountModal.vue | 21 +-- .../components/account/EditAccountModal.vue | 27 +++- .../src/components/admin/usage/UsageTable.vue | 22 ++- frontend/src/i18n/locales/en.ts | 12 +- frontend/src/i18n/locales/zh.ts | 12 +- frontend/src/views/admin/ChannelsView.vue | 143 ++++++++++++++++-- frontend/src/views/user/UsageView.vue | 22 ++- 18 files changed, 346 insertions(+), 79 deletions(-) create mode 100644 backend/migrations/105_migrate_websearch_emulation_to_tristate.sql diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 24ee02fd..69daeecf 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -176,7 +176,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { channelRepository := repository.NewChannelRepository(db) channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) - balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository) + balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 59619d50..8ec54420 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -248,6 +248,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } + // 设置请求所属分组 ID(用于渠道级功能判断,如 WebSearch 模拟) + parsedReq.GroupID = apiKey.GroupID + // 计算粘性会话hash parsedReq.SessionContext = &service.SessionContext{ ClientIP: ip.GetClientIP(c), diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 6e5a768f..4d933986 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -1169,15 +1169,30 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool { return ok && enabled } -// IsWebSearchEmulationEnabled 返回 Anthropic API Key 账号是否启用 web search 模拟。 -// 字段:accounts.extra.web_search_emulation。 -// 字段缺失或类型不正确时,按 false(关闭)处理。 -func (a *Account) IsWebSearchEmulationEnabled() bool { +// WebSearch 模拟三态常量 +const ( + WebSearchModeDefault = "default" // 跟随渠道配置 + WebSearchModeEnabled = "enabled" // 强制开启 + WebSearchModeDisabled = "disabled" // 强制关闭 +) + +// GetWebSearchEmulationMode 返回账号的 WebSearch 模拟模式。 +// 三态:default(跟随渠道)/ enabled(强制开启)/ disabled(强制关闭)。 +// 旧 bool 值需通过 SQL 迁移脚本转换,Go 代码不做兼容。 +func (a *Account) GetWebSearchEmulationMode() string { if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil { - return false + return WebSearchModeDefault + } + mode, ok := a.Extra[featureKeyWebSearchEmulation].(string) + if !ok { + return WebSearchModeDefault + } + switch mode { + case WebSearchModeEnabled, WebSearchModeDisabled: + return mode + default: + return WebSearchModeDefault } - enabled, ok := a.Extra[featureKeyWebSearchEmulation].(bool) - return ok && enabled } // IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。 diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index 4a896d9f..e88f7f8c 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -8,11 +8,17 @@ import ( // resolveAccountStatsCost 计算账号统计定价费用。 // 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。 -// 仅匹配自定义规则(AccountStatsPricingRules),按数组顺序先命中为准。 -// upstreamModel 是最终发往上游的模型 ID,用于匹配自定义规则中的模型定价。 +// +// 优先级(先命中为准): +// 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关) +// 2. ApplyPricingToAccountStats 启用时,用模型定价文件(LiteLLM)中上游模型的标准价格计算 +// 3. nil → 走默认公式 +// +// upstreamModel 是最终发往上游的模型 ID。 func resolveAccountStatsCost( ctx context.Context, channelService *ChannelService, + billingService *BillingService, accountID int64, groupID int64, upstreamModel string, @@ -23,12 +29,39 @@ func resolveAccountStatsCost( return nil } channel, err := channelService.GetChannelForGroup(ctx, groupID) - if err != nil || channel == nil || !channel.ApplyPricingToAccountStats { + if err != nil || channel == nil { return nil } platform := channelService.GetGroupPlatform(ctx, groupID) - return tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount) + + // 优先级 1:自定义规则(始终尝试) + if cost := tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount); cost != nil { + return cost + } + + // 优先级 2:模型定价文件(LiteLLM/fallback)中上游模型的标准价格 + if channel.ApplyPricingToAccountStats && billingService != nil { + return tryModelFilePricing(billingService, upstreamModel, tokens) + } + + return nil +} + +// tryModelFilePricing 使用模型定价文件(LiteLLM/fallback)中的标准价格计算费用。 +func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 { + pricing, err := billingService.GetModelPricing(model) + if err != nil || pricing == nil { + return nil + } + cost := float64(tokens.InputTokens)*pricing.InputPricePerToken + + float64(tokens.OutputTokens)*pricing.OutputPricePerToken + + float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken + + float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken + if cost <= 0 { + return nil + } + return &cost } // tryCustomRules 遍历自定义规则,按数组顺序先命中为准。 diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 053451e1..23411ed5 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -27,17 +27,24 @@ var quotaDimLabels = map[string]string{ quotaDimTotal: "总限额 / Total", } +// AccountQuotaReader provides read access to account quota data. +type AccountQuotaReader interface { + GetByID(ctx context.Context, id int64) (*Account, error) +} + // BalanceNotifyService handles balance and quota threshold notifications. type BalanceNotifyService struct { emailService *EmailService settingRepo SettingRepository + accountRepo AccountQuotaReader } // NewBalanceNotifyService creates a new BalanceNotifyService. -func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository) *BalanceNotifyService { +func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountQuotaReader) *BalanceNotifyService { return &BalanceNotifyService{ emailService: emailService, settingRepo: settingRepo, + accountRepo: accountRepo, } } @@ -110,7 +117,7 @@ func buildQuotaDims(account *Account) []quotaDim { } // CheckAccountQuotaAfterIncrement checks if any quota dimension crossed above its notify threshold. -// The account's Extra fields contain pre-increment usage values. +// It fetches real-time quota usage from DB to avoid stale snapshot values. func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Context, account *Account, cost float64) { if account == nil || s.emailService == nil || s.settingRepo == nil || cost <= 0 { return @@ -123,8 +130,29 @@ func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Conte return } + freshAccount := s.fetchFreshAccount(ctx, account) siteName := s.getSiteName(ctx) - for _, dim := range buildQuotaDims(account) { + s.checkQuotaDimCrossings(freshAccount, cost, adminEmails, siteName) +} + +// fetchFreshAccount loads the latest account from DB; falls back to the snapshot on error. +func (s *BalanceNotifyService) fetchFreshAccount(ctx context.Context, snapshot *Account) *Account { + if s.accountRepo == nil { + return snapshot + } + fresh, err := s.accountRepo.GetByID(ctx, snapshot.ID) + if err != nil { + slog.Warn("failed to fetch fresh account for quota notify, using snapshot", + "account_id", snapshot.ID, "error", err) + return snapshot + } + return fresh +} + +// checkQuotaDimCrossings iterates quota dimensions and sends alerts for threshold crossings. +// freshAccount has post-increment values; oldUsed is reconstructed as freshUsed - cost. +func (s *BalanceNotifyService) checkQuotaDimCrossings(freshAccount *Account, cost float64, adminEmails []string, siteName string) { + for _, dim := range buildQuotaDims(freshAccount) { if !dim.enabled || dim.threshold <= 0 { continue } @@ -132,9 +160,12 @@ func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Conte if effectiveThreshold <= 0 { continue } - newUsed := dim.oldUsed + cost - if dim.oldUsed < effectiveThreshold && newUsed >= effectiveThreshold { - s.asyncSendQuotaAlert(adminEmails, account.Name, dim, newUsed, effectiveThreshold, siteName) + // dim.oldUsed is actually the post-increment value from fresh DB data; + // reconstruct pre-increment value to detect threshold crossing. + newUsed := dim.oldUsed + oldUsed := dim.oldUsed - cost + if oldUsed < effectiveThreshold && newUsed >= effectiveThreshold { + s.asyncSendQuotaAlert(adminEmails, freshAccount.Name, dim, newUsed, effectiveThreshold, siteName) } } } diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index e2badfed..55cb2c84 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -75,6 +75,9 @@ type ParsedRequest struct { MaxTokens int // max_tokens 值(用于探测请求拦截) SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变) + // GroupID 请求所属分组 ID(来自 API Key) + GroupID *int64 + // OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁) // 流式请求在收到 2xx 响应头后调用,避免持锁等流完成 OnUpstreamAccepted func() diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 70dd9b52..5267156d 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3789,7 +3789,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应 - if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.Body) { + if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) { return s.handleWebSearchEmulation(ctx, c, account, parsed) } @@ -7588,7 +7588,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage upstreamModel = result.Model } usageLog.AccountStatsCost = resolveAccountStatsCost( - ctx, s.channelService, + ctx, s.channelService, s.billingService, account.ID, *apiKey.GroupID, upstreamModel, UsageTokens{ InputTokens: result.Usage.InputTokens, diff --git a/backend/internal/service/gateway_websearch_emulation.go b/backend/internal/service/gateway_websearch_emulation.go index b3a4aa69..0d0b5480 100644 --- a/backend/internal/service/gateway_websearch_emulation.go +++ b/backend/internal/service/gateway_websearch_emulation.go @@ -49,10 +49,9 @@ func getWebSearchManager() *websearch.Manager { // shouldEmulateWebSearch checks whether a request should be intercepted. // -// Judgment chain: manager exists → only web_search tool → global enabled → account enabled. -// Note: channel-level control is enforced via the account's extra field; the channel toggle -// in the admin UI sets the account's flag for all accounts in that channel's groups. -func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, body []byte) bool { +// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled. +// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel). +func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool { if getWebSearchManager() == nil { return false } @@ -62,10 +61,23 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac if !s.settingService.IsWebSearchEmulationEnabled(ctx) { return false } - if !account.IsWebSearchEmulationEnabled() { + + mode := account.GetWebSearchEmulationMode() + switch mode { + case WebSearchModeEnabled: + return true + case WebSearchModeDisabled: return false + default: // "default" → follow channel config + if groupID == nil || s.channelService == nil { + return false + } + ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) + if err != nil || ch == nil { + return false + } + return ch.IsWebSearchEmulationEnabled(account.Platform) } - return true } // isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool. diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 98258cd0..e060b981 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4580,7 +4580,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec statsModel = result.Model } usageLog.AccountStatsCost = resolveAccountStatsCost( - ctx, s.channelService, + ctx, s.channelService, s.billingService, account.ID, *apiKey.GroupID, statsModel, tokens, 1, ) diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index b4e33039..9f33c46a 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -476,8 +476,8 @@ func ProvidePaymentConfigService(entClient *dbent.Client, settingRepo SettingRep } // ProvideBalanceNotifyService creates BalanceNotifyService -func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository) *BalanceNotifyService { - return NewBalanceNotifyService(emailService, settingRepo) +func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountRepository) *BalanceNotifyService { + return NewBalanceNotifyService(emailService, settingRepo, accountRepo) } // ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService. diff --git a/backend/migrations/105_migrate_websearch_emulation_to_tristate.sql b/backend/migrations/105_migrate_websearch_emulation_to_tristate.sql new file mode 100644 index 00000000..745e58df --- /dev/null +++ b/backend/migrations/105_migrate_websearch_emulation_to_tristate.sql @@ -0,0 +1,11 @@ +-- Convert old boolean web_search_emulation to tri-state string +-- true → "enabled", false → remove key (becomes "default") +UPDATE accounts +SET extra = (extra - 'web_search_emulation') || jsonb_build_object('web_search_emulation', 'enabled') +WHERE extra ? 'web_search_emulation' + AND extra->>'web_search_emulation' = 'true'; + +UPDATE accounts +SET extra = extra - 'web_search_emulation' +WHERE extra ? 'web_search_emulation' + AND extra->>'web_search_emulation' = 'false'; diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index e0dbce61..e83e061e 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2337,7 +2337,11 @@ {{ t('admin.accounts.anthropic.webSearchEmulationDesc') }}

- + @@ -2846,7 +2850,6 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue' import Select from '@/components/common/Select.vue' import Icon from '@/components/icons/Icon.vue' import ProxySelector from '@/components/common/ProxySelector.vue' -import Toggle from '@/components/common/Toggle.vue' import GroupSelector from '@/components/common/GroupSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue' @@ -2997,7 +3000,7 @@ const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const codexCLIOnlyEnabled = ref(false) const anthropicPassthroughEnabled = ref(false) -const webSearchEmulationEnabled = ref(false) +const webSearchEmulationMode = ref('default') const webSearchGlobalEnabled = ref(false) // Load web search global state once @@ -3331,7 +3334,7 @@ watch( } if (newPlatform !== 'anthropic') { anthropicPassthroughEnabled.value = false - webSearchEmulationEnabled.value = false + webSearchEmulationMode.value = 'default' } // Reset OAuth states oauth.resetState() @@ -3351,7 +3354,7 @@ watch( } if (platform !== 'anthropic' || category !== 'apikey') { anthropicPassthroughEnabled.value = false - webSearchEmulationEnabled.value = false + webSearchEmulationMode.value = 'default' } } ) @@ -3716,7 +3719,7 @@ const resetForm = () => { openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF codexCLIOnlyEnabled.value = false anthropicPassthroughEnabled.value = false - webSearchEmulationEnabled.value = false + webSearchEmulationMode.value = 'default' // Reset quota control state windowCostEnabled.value = false windowCostLimit.value = null @@ -3804,10 +3807,10 @@ const buildAnthropicExtra = (base?: Record): Record 0 ? extra : undefined diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index abb9569e..74e5fc1f 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1161,7 +1161,11 @@ {{ t('admin.accounts.anthropic.webSearchEmulationDesc') }}

- + @@ -1844,7 +1848,6 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue' import Select from '@/components/common/Select.vue' import Icon from '@/components/icons/Icon.vue' import ProxySelector from '@/components/common/ProxySelector.vue' -import Toggle from '@/components/common/Toggle.vue' import GroupSelector from '@/components/common/GroupSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue' @@ -1986,7 +1989,7 @@ const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const codexCLIOnlyEnabled = ref(false) const anthropicPassthroughEnabled = ref(false) -const webSearchEmulationEnabled = ref(false) +const webSearchEmulationMode = ref('default') const webSearchGlobalEnabled = ref(false) // Load web search global state once @@ -2171,7 +2174,7 @@ const syncFormFromAccount = (newAccount: Account | null) => { openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF codexCLIOnlyEnabled.value = false anthropicPassthroughEnabled.value = false - webSearchEmulationEnabled.value = false + webSearchEmulationMode.value = 'default' if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) { openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, { @@ -2192,7 +2195,15 @@ const syncFormFromAccount = (newAccount: Account | null) => { } if (newAccount.platform === 'anthropic' && newAccount.type === 'apikey') { anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true - webSearchEmulationEnabled.value = extra?.web_search_emulation === true + // 三态:string "default"/"enabled"/"disabled",向后兼容旧 bool + const wsVal = extra?.web_search_emulation + if (wsVal === 'enabled' || wsVal === 'disabled') { + webSearchEmulationMode.value = wsVal + } else if (wsVal === true) { + webSearchEmulationMode.value = 'enabled' + } else { + webSearchEmulationMode.value = 'default' + } } // Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above) @@ -3180,10 +3191,10 @@ const handleSubmit = async () => { } else { delete newExtra.anthropic_passthrough } - if (webSearchEmulationEnabled.value) { - newExtra.web_search_emulation = true - } else { + if (webSearchEmulationMode.value === 'default') { delete newExtra.web_search_emulation + } else { + newExtra.web_search_emulation = webSearchEmulationMode.value } updatePayload.extra = newExtra } diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index f4494e69..92c8dd34 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -279,13 +279,21 @@ {{ t('admin.usage.outputCost') }} ${{ tooltipData.output_cost.toFixed(6) }} -
- {{ t('usage.inputTokenPrice') }} - {{ formatTokenPricePerMillion(tooltipData.input_cost, tooltipData.input_tokens) }} {{ t('usage.perMillionTokens') }} -
-
- {{ t('usage.outputTokenPrice') }} - {{ formatTokenPricePerMillion(tooltipData.output_cost, tooltipData.output_tokens) }} {{ t('usage.perMillionTokens') }} + + + +
+ {{ tooltipData.billing_mode === 'image' ? t('usage.imageUnitPrice') : t('usage.unitPrice') }} + ${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}
{{ t('admin.usage.cacheCreationCost') }} diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index f45a02f6..056bdde0 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -774,6 +774,8 @@ export default { inputTokenPrice: 'Input price', outputTokenPrice: 'Output price', perMillionTokens: '/ 1M tokens', + unitPrice: 'Per-request price', + imageUnitPrice: 'Per-image price', cacheRead: 'Read', cacheWrite: 'Write', serviceTier: 'Service tier', @@ -1877,14 +1879,15 @@ export default { pricingEntry: 'Pricing Entry', noModels: 'No models added', applyPricingToAccountStats: 'Apply Pricing to Account Stats', - applyPricingToAccountStatsDesc: 'When enabled, custom account stats model pricing rules will be applied.', + applyPricingToAccountStatsDesc: 'When enabled, requests not matched by custom rules will use standard model pricing for account stats calculation', accountStatsPricingRules: 'Custom Account Stats Pricing Rules', addRule: 'Add Rule', noRulesConfigured: 'No custom rules configured. Channel model pricing above will be used.', ruleName: 'Rule name (optional)', ruleGroups: 'Groups', - ruleAccounts: 'Account IDs', - ruleAccountsPlaceholder: 'Enter account IDs, comma-separated', + ruleAccounts: 'Accounts', + searchAccountPlaceholder: 'Search accounts...', + ruleAccountsHint: 'Leave empty to match all accounts', ruleModelPricing: 'Model Pricing', noGroupsInChannel: 'No groups selected in platform tabs above' } @@ -2380,6 +2383,9 @@ export default { webSearchEmulation: 'Web Search Emulation', webSearchEmulationDesc: 'Enable web search emulation for this API Key account. When a pure web_search request is detected, the gateway calls a third-party search API and constructs the response locally.', + webSearchDefault: 'Default (follow channel)', + webSearchEnabled: 'Enabled', + webSearchDisabled: 'Disabled', }, modelRestriction: 'Model Restriction (Optional)', modelWhitelist: 'Model Whitelist', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 61d43a37..47a1f8d4 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -778,6 +778,8 @@ export default { inputTokenPrice: '输入单价', outputTokenPrice: '输出单价', perMillionTokens: '/ 1M Token', + unitPrice: '单次价格', + imageUnitPrice: '单张价格', cacheRead: '读取', cacheWrite: '写入', serviceTier: '服务档位', @@ -1956,14 +1958,15 @@ export default { pricingEntry: '定价配置', noModels: '未添加模型', applyPricingToAccountStats: '应用模型定价到账号统计', - applyPricingToAccountStatsDesc: '启用后将支持自定义账号统计的模型价格', + applyPricingToAccountStatsDesc: '启用后,未被自定义规则匹配的请求将使用模型定价文件中的标准价格计算账号统计费用', accountStatsPricingRules: '自定义账号统计定价规则', addRule: '添加规则', noRulesConfigured: '未配置自定义规则,将使用上方的模型定价。', ruleName: '规则名称(可选)', ruleGroups: '分组', - ruleAccounts: '账号 ID', - ruleAccountsPlaceholder: '输入账号 ID,逗号分隔', + ruleAccounts: '账号', + searchAccountPlaceholder: '搜索账号...', + ruleAccountsHint: '留空表示匹配所有账号', ruleModelPricing: '模型定价', noGroupsInChannel: '上方平台标签页中未选择分组' } @@ -2527,6 +2530,9 @@ export default { webSearchEmulation: 'Web Search 模拟', webSearchEmulationDesc: '为该 API Key 账号启用 web search 模拟。客户端发送纯 web_search 请求时,由网关调用第三方搜索 API 并构造响应返回。', + webSearchDefault: '默认(跟随渠道)', + webSearchEnabled: '开启', + webSearchDisabled: '关闭', }, modelRestriction: '模型限制(可选)', modelWhitelist: '模型白名单', diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 2e45ba42..5be45cb5 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -413,8 +413,8 @@
- -
+ +

{{ t('admin.channels.form.accountStatsPricingRules') }} @@ -474,12 +474,51 @@
- + +
+ + {{ getRuleAccountLabel(accountId) }} + + +
+ + +

+ {{ t('admin.channels.form.ruleAccountsHint') }} +

@@ -569,6 +608,7 @@ import PlatformIcon from '@/components/common/PlatformIcon.vue' import Toggle from '@/components/common/Toggle.vue' import PricingEntryCard from '@/components/admin/channel/PricingEntryCard.vue' import { getPersistedPageSize } from '@/composables/usePersistedPageSize' +import { useKeyedDebouncedSearch } from '@/composables/useKeyedDebouncedSearch' const { t } = useI18n() const appStore = useAppStore() @@ -852,6 +892,9 @@ function addRulePricingEntry(ruleIndex: number) { function removeAccountStatsRule(ruleIndex: number) { form.account_stats_pricing_rules.splice(ruleIndex, 1) + // Clear all search state since indices shift after removal + ruleAccountSearchRunner.clearAll() + clearAllRuleAccountSearchState() } function removeRulePricingEntry(ruleIndex: number, pricingIndex: number) { @@ -863,11 +906,78 @@ function getGroupNameById(groupId: number): string { return group ? group.name : `#${groupId}` } -function parseAccountIdsInput(value: string): number[] { - return value - .split(',') - .map(s => parseInt(s.trim())) - .filter(n => !isNaN(n) && n > 0) +// ── Account search for pricing rules ── +interface SimpleAccount { id: number; name: string } + +const ruleAccountSearchKeyword = ref>({}) +const ruleAccountSearchResults = ref>({}) +const showRuleAccountDropdown = ref>({}) +// Cache: account ID → name, populated when search results are selected +const ruleAccountNameCache = ref>({}) + +const ruleAccountSearchRunner = useKeyedDebouncedSearch({ + delay: 300, + search: async (keyword, { key, signal }) => { + const platform = key.split('-')[0] + const res = await adminAPI.accounts.list(1, 20, { platform, search: keyword }, { signal }) + return res.items.map(a => ({ id: a.id, name: a.name })) + }, + onSuccess: (key, result) => { ruleAccountSearchResults.value[key] = result }, + onError: (key) => { ruleAccountSearchResults.value[key] = [] }, +}) + +function onRuleAccountSearchInput(platform: string, ruleIndex: number) { + const key = `${platform}-${ruleIndex}` + showRuleAccountDropdown.value[key] = true + ruleAccountSearchRunner.trigger(key, ruleAccountSearchKeyword.value[key] || '') +} + +function onRuleAccountSearchFocus(platform: string, ruleIndex: number) { + const key = `${platform}-${ruleIndex}` + showRuleAccountDropdown.value[key] = true + if (!ruleAccountSearchResults.value[key]?.length) { + ruleAccountSearchRunner.trigger(key, ruleAccountSearchKeyword.value[key] || '') + } +} + +function selectRuleAccount( + rule: { account_ids: number[] }, + account: SimpleAccount, + platform: string, + ruleIndex: number, +) { + if (!rule.account_ids.includes(account.id)) { + rule.account_ids.push(account.id) + ruleAccountNameCache.value[account.id] = account.name + } + const key = `${platform}-${ruleIndex}` + ruleAccountSearchKeyword.value[key] = '' + showRuleAccountDropdown.value[key] = false +} + +function removeRuleAccount(rule: { account_ids: number[] }, accountId: number) { + const idx = rule.account_ids.indexOf(accountId) + if (idx !== -1) rule.account_ids.splice(idx, 1) +} + +function getRuleAccountLabel(accountId: number): string { + const name = ruleAccountNameCache.value[accountId] + return name ? `${name} #${accountId}` : `#${accountId}` +} + +function handleRuleAccountClickOutside(event: MouseEvent) { + const target = event.target as HTMLElement + if (!target.closest('.rule-account-search-container')) { + Object.keys(showRuleAccountDropdown.value).forEach(key => { + showRuleAccountDropdown.value[key] = false + }) + } +} + +function clearAllRuleAccountSearchState() { + ruleAccountSearchKeyword.value = {} + ruleAccountSearchResults.value = {} + showRuleAccountDropdown.value = {} } function accountStatsRulesToAPI(): AccountStatsPricingRule[] { @@ -1093,6 +1203,9 @@ function resetForm() { form.apply_pricing_to_account_stats = false form.account_stats_pricing_rules = [] activeTab.value = 'basic' + ruleAccountSearchRunner.clearAll() + clearAllRuleAccountSearchState() + ruleAccountNameCache.value = {} } async function openCreateDialog() { @@ -1313,11 +1426,15 @@ onMounted(() => { loadChannels() loadGroups() loadWebSearchGlobalState() + document.addEventListener('click', handleRuleAccountClickOutside) }) onUnmounted(() => { clearTimeout(searchTimeout) abortController?.abort() + document.removeEventListener('click', handleRuleAccountClickOutside) + ruleAccountSearchRunner.clearAll() + clearAllRuleAccountSearchState() }) diff --git a/frontend/src/views/user/UsageView.vue b/frontend/src/views/user/UsageView.vue index 6cb367ed..2ec0fea5 100644 --- a/frontend/src/views/user/UsageView.vue +++ b/frontend/src/views/user/UsageView.vue @@ -447,13 +447,21 @@ {{ t('admin.usage.outputCost') }} ${{ tooltipData.output_cost.toFixed(6) }}
-
- {{ t('usage.inputTokenPrice') }} - {{ formatTokenPricePerMillion(tooltipData.input_cost, tooltipData.input_tokens) }} {{ t('usage.perMillionTokens') }} -
-
- {{ t('usage.outputTokenPrice') }} - {{ formatTokenPricePerMillion(tooltipData.output_cost, tooltipData.output_tokens) }} {{ t('usage.perMillionTokens') }} + + + +
+ {{ tooltipData.billing_mode === 'image' ? t('usage.imageUnitPrice') : t('usage.unitPrice') }} + ${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}
{{ t('admin.usage.cacheCreationCost') }} From a68df457d80dd67b173df2a74fd69de1206c0f2e Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 12:07:09 +0800 Subject: [PATCH 42/88] fix: address audit findings across websearch, notify, and channel pricing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Backend fixes: - Fix balance notify ignoring percentage threshold type (was treating percentage value as fixed USD amount) - Remove dead code parseJSONStringArray - Add ImageOutputTokens to tryModelFilePricing calculation - Unify zero-value check: cost == 0 → cost <= 0 in calculateTokenStatsCost - Use MarshalNotifyEmails instead of json.Marshal for consistency - Rename quotaDim.oldUsed → currentUsed for clarity - Extract HTML email templates to const variables (function ≤30 lines) Test fixes: - Rewrite account_websearch_test.go for GetWebSearchEmulationMode tri-state - Add 6 tryModelFilePricing test cases Frontend fixes: - Replace hardcoded '未命名' with i18n key - Extract getBillingModeLabel/getBillingModeBadgeClass to shared utils - Replace inline type with imported NotifyEmailEntry - Pass platform to AccountStats pricing rules via inferRulePlatform() - Add billing mode constants (BILLING_MODE_TOKEN/PER_REQUEST/IMAGE) --- .../internal/service/account_stats_pricing.go | 5 +- .../service/account_stats_pricing_test.go | 99 +++++++++++++++++++ .../service/account_websearch_test.go | 96 ++++++++++++------ .../service/balance_notify_service.go | 79 ++++++++------- backend/internal/service/setting_service.go | 6 +- .../src/components/admin/usage/UsageTable.vue | 14 +-- frontend/src/i18n/locales/en.ts | 3 +- frontend/src/i18n/locales/zh.ts | 3 +- frontend/src/utils/billingMode.ts | 19 ++++ frontend/src/views/admin/ChannelsView.vue | 52 ++++++---- frontend/src/views/admin/SettingsView.vue | 4 +- frontend/src/views/user/UsageView.vue | 16 +-- 12 files changed, 275 insertions(+), 121 deletions(-) create mode 100644 frontend/src/utils/billingMode.ts diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index e88f7f8c..cbe9c76c 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -57,7 +57,8 @@ func tryModelFilePricing(billingService *BillingService, model string, tokens Us cost := float64(tokens.InputTokens)*pricing.InputPricePerToken + float64(tokens.OutputTokens)*pricing.OutputPricePerToken + float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken + - float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken + float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken + + float64(tokens.ImageOutputTokens)*pricing.ImageOutputPricePerToken if cost <= 0 { return nil } @@ -194,7 +195,7 @@ func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) * float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) + float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) + float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice) - if cost == 0 { + if cost <= 0 { return nil } return &cost diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go index bc3db251..bf9da978 100644 --- a/backend/internal/service/account_stats_pricing_test.go +++ b/backend/internal/service/account_stats_pricing_test.go @@ -428,3 +428,102 @@ func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) { require.NotNil(t, result) require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2 } + +// --------------------------------------------------------------------------- +// tryModelFilePricing +// --------------------------------------------------------------------------- + +// newTestBillingServiceWithPrices creates a BillingService with pre-populated +// fallback prices for testing. No config or pricing service is needed. +// The key must match what getFallbackPricing resolves to for a given model name. +// E.g., model "claude-sonnet-4" resolves to key "claude-sonnet-4". +func newTestBillingServiceWithPrices(prices map[string]*ModelPricing) *BillingService { + return &BillingService{ + fallbackPrices: prices, + } +} + +func TestTryModelFilePricing_Success(t *testing.T) { + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + }, + }) + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2 + require.InDelta(t, 0.2, *result, 1e-12) +} + +func TestTryModelFilePricing_PricingNotFound(t *testing.T) { + // "nonexistent-model" does not match any fallback pattern + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{}) + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + result := tryModelFilePricing(bs, "nonexistent-model", tokens) + require.Nil(t, result) +} + +func TestTryModelFilePricing_NilFallback(t *testing.T) { + // getFallbackPricing returns nil when key maps to nil + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": nil, + }) + tokens := UsageTokens{InputTokens: 100} + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.Nil(t, result) +} + +func TestTryModelFilePricing_ZeroCost(t *testing.T) { + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + }, + }) + tokens := UsageTokens{} // all zero tokens → cost = 0 → nil + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.Nil(t, result) +} + +func TestTryModelFilePricing_WithImageOutput(t *testing.T) { + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + ImageOutputPricePerToken: 0.01, + }, + }) + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + ImageOutputTokens: 10, + } + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3 + require.InDelta(t, 0.3, *result, 1e-12) +} + +func TestTryModelFilePricing_WithCacheTokens(t *testing.T) { + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + CacheCreationPricePerToken: 0.003, + CacheReadPricePerToken: 0.0005, + }, + }) + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + CacheCreationTokens: 200, + CacheReadTokens: 300, + } + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005 + // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95 + require.InDelta(t, 0.95, *result, 1e-12) +} diff --git a/backend/internal/service/account_websearch_test.go b/backend/internal/service/account_websearch_test.go index fe742ebf..b4d23c6b 100644 --- a/backend/internal/service/account_websearch_test.go +++ b/backend/internal/service/account_websearch_test.go @@ -1,3 +1,5 @@ +//go:build unit + package service import ( @@ -6,66 +8,98 @@ import ( "github.com/stretchr/testify/require" ) -func TestAccount_IsWebSearchEmulationEnabled_Enabled(t *testing.T) { +func TestGetWebSearchEmulationMode_Enabled(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"}, + } + require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_Disabled(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: "disabled"}, + } + require.Equal(t, WebSearchModeDisabled, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_Default(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: "default"}, + } + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_UnknownString(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: "unknown"}, + } + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_OldBoolTrue(t *testing.T) { a := &Account{ Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: map[string]any{featureKeyWebSearchEmulation: true}, } - require.True(t, a.IsWebSearchEmulationEnabled()) + // bool is not a string, type assertion fails → default + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) } -func TestAccount_IsWebSearchEmulationEnabled_Disabled(t *testing.T) { +func TestGetWebSearchEmulationMode_OldBoolFalse(t *testing.T) { a := &Account{ Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: map[string]any{featureKeyWebSearchEmulation: false}, } - require.False(t, a.IsWebSearchEmulationEnabled()) + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) } -func TestAccount_IsWebSearchEmulationEnabled_MissingField(t *testing.T) { +func TestGetWebSearchEmulationMode_NilAccount(t *testing.T) { + var a *Account + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_NilExtra(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: nil, + } + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_MissingField(t *testing.T) { a := &Account{ Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: map[string]any{}, } - require.False(t, a.IsWebSearchEmulationEnabled()) + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) } -func TestAccount_IsWebSearchEmulationEnabled_WrongType(t *testing.T) { - a := &Account{ - Platform: PlatformAnthropic, - Type: AccountTypeAPIKey, - Extra: map[string]any{featureKeyWebSearchEmulation: "true"}, - } - require.False(t, a.IsWebSearchEmulationEnabled()) -} - -func TestAccount_IsWebSearchEmulationEnabled_NilExtra(t *testing.T) { - a := &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: nil} - require.False(t, a.IsWebSearchEmulationEnabled()) -} - -func TestAccount_IsWebSearchEmulationEnabled_NilAccount(t *testing.T) { - var a *Account - require.False(t, a.IsWebSearchEmulationEnabled()) -} - -func TestAccount_IsWebSearchEmulationEnabled_NonAnthropicPlatform(t *testing.T) { +func TestGetWebSearchEmulationMode_NonAnthropicPlatform(t *testing.T) { a := &Account{ Platform: PlatformOpenAI, Type: AccountTypeAPIKey, - Extra: map[string]any{featureKeyWebSearchEmulation: true}, + Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"}, } - require.False(t, a.IsWebSearchEmulationEnabled()) + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) } -func TestAccount_IsWebSearchEmulationEnabled_NonAPIKeyType(t *testing.T) { +func TestGetWebSearchEmulationMode_NonAPIKeyType(t *testing.T) { a := &Account{ Platform: PlatformAnthropic, Type: AccountTypeOAuth, - Extra: map[string]any{featureKeyWebSearchEmulation: true}, + Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"}, } - require.False(t, a.IsWebSearchEmulationEnabled()) + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) } diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 23411ed5..1e4d8ff6 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -2,7 +2,6 @@ package service import ( "context" - "encoding/json" "fmt" "html" "log/slog" @@ -14,6 +13,10 @@ import ( const ( emailSendTimeout = 30 * time.Second + // Threshold type values + thresholdTypeFixed = "fixed" + thresholdTypePercentage = "percentage" + // Quota dimension labels quotaDimDaily = "daily" quotaDimWeekly = "weekly" @@ -48,6 +51,15 @@ func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepo } } +// resolveBalanceThreshold returns the effective balance threshold. +// For percentage type, it computes threshold = totalRecharged * percentage / 100. +func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 { + if thresholdType == thresholdTypePercentage && totalRecharged > 0 { + return totalRecharged * threshold / 100 + } + return threshold +} + // CheckBalanceAfterDeduction checks if balance crossed below threshold after deduction. // oldBalance is the balance before deduction, cost is the amount deducted. // Notification is sent only on first crossing: oldBalance >= threshold && newBalance < threshold. @@ -73,8 +85,13 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u return } + effectiveThreshold := resolveBalanceThreshold(threshold, user.BalanceNotifyThresholdType, user.TotalRecharged) + if effectiveThreshold <= 0 { + return + } + newBalance := oldBalance - cost - if oldBalance >= threshold && newBalance < threshold { + if oldBalance >= effectiveThreshold && newBalance < effectiveThreshold { siteName := s.getSiteName(ctx) recipients := s.collectBalanceNotifyRecipients(user) go func() { @@ -83,7 +100,7 @@ func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, u slog.Error("panic in balance notification", "recover", r) } }() - s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName) + s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, effectiveThreshold, siteName) }() } } @@ -94,14 +111,14 @@ type quotaDim struct { enabled bool threshold float64 thresholdType string // "fixed" (default) or "percentage" - oldUsed float64 + currentUsed float64 limit float64 } // resolvedThreshold returns the effective threshold value. // For percentage type, it computes threshold = limit * percentage / 100. func (d quotaDim) resolvedThreshold() float64 { - if d.thresholdType == "percentage" && d.limit > 0 { + if d.thresholdType == thresholdTypePercentage && d.limit > 0 { return d.limit * d.threshold / 100 } return d.threshold @@ -150,7 +167,7 @@ func (s *BalanceNotifyService) fetchFreshAccount(ctx context.Context, snapshot * } // checkQuotaDimCrossings iterates quota dimensions and sends alerts for threshold crossings. -// freshAccount has post-increment values; oldUsed is reconstructed as freshUsed - cost. +// freshAccount has post-increment values; pre-increment is reconstructed as currentUsed - cost. func (s *BalanceNotifyService) checkQuotaDimCrossings(freshAccount *Account, cost float64, adminEmails []string, siteName string) { for _, dim := range buildQuotaDims(freshAccount) { if !dim.enabled || dim.threshold <= 0 { @@ -160,10 +177,10 @@ func (s *BalanceNotifyService) checkQuotaDimCrossings(freshAccount *Account, cos if effectiveThreshold <= 0 { continue } - // dim.oldUsed is actually the post-increment value from fresh DB data; + // currentUsed is the post-increment value from fresh DB data; // reconstruct pre-increment value to detect threshold crossing. - newUsed := dim.oldUsed - oldUsed := dim.oldUsed - cost + newUsed := dim.currentUsed + oldUsed := dim.currentUsed - cost if oldUsed < effectiveThreshold && newUsed >= effectiveThreshold { s.asyncSendQuotaAlert(adminEmails, freshAccount.Name, dim, newUsed, effectiveThreshold, siteName) } @@ -309,10 +326,9 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension) } -// buildBalanceLowEmailBody builds HTML email for balance low notification. -// Lines exceed 30 due to inline HTML template (not splittable). -func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance, threshold float64, siteName string) string { - return fmt.Sprintf(` +// balanceLowEmailTemplate is the HTML template for balance low notifications. +// Format args: siteName, userName, userName, balance, threshold, threshold. +const balanceLowEmailTemplate = ` @@ -344,17 +360,11 @@ func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance
-`, siteName, userName, userName, balance, threshold, threshold) -} +` -// buildQuotaAlertEmailBody builds HTML email for account quota alert. -// Lines exceed 30 due to inline HTML template (not splittable). -func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountName, dimLabel string, used, limit, threshold float64, siteName string) string { - limitStr := fmt.Sprintf("$%.2f", limit) - if limit <= 0 { - limitStr = "无限制 / Unlimited" - } - return fmt.Sprintf(` +// quotaAlertEmailTemplate is the HTML template for account quota alert notifications. +// Format args: siteName, accountName, dimLabel, used, limitStr, threshold. +const quotaAlertEmailTemplate = ` @@ -389,18 +399,19 @@ func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountName, dimLabel st
-`, siteName, accountName, dimLabel, used, limitStr, threshold) +` + +// buildBalanceLowEmailBody builds HTML email for balance low notification. +func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance, threshold float64, siteName string) string { + return fmt.Sprintf(balanceLowEmailTemplate, siteName, userName, userName, balance, threshold, threshold) } -// parseJSONStringArray parses a JSON string array, returns nil on error. -func parseJSONStringArray(raw string) []string { - raw = strings.TrimSpace(raw) - if raw == "" || raw == "[]" { - return nil +// buildQuotaAlertEmailBody builds HTML email for account quota alert. +func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountName, dimLabel string, used, limit, threshold float64, siteName string) string { + limitStr := fmt.Sprintf("$%.2f", limit) + if limit <= 0 { + limitStr = "无限制 / Unlimited" } - var result []string - if err := json.Unmarshal([]byte(raw), &result); err != nil { - return nil - } - return result + return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountName, dimLabel, used, limitStr, threshold) } + diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 0267040d..28eb3a70 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -627,11 +627,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled) updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64) updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled) - accountQuotaNotifyEmailsJSON, err := json.Marshal(settings.AccountQuotaNotifyEmails) - if err != nil { - return fmt.Errorf("marshal account quota notify emails: %w", err) - } - updates[SettingKeyAccountQuotaNotifyEmails] = string(accountQuotaNotifyEmailsJSON) + updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails) err = s.settingRepo.SetMultiple(ctx, updates) if err == nil { diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index 92c8dd34..c405e66b 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -87,7 +87,7 @@ @@ -346,6 +346,7 @@ import { formatCacheTokens, formatMultiplier } from '@/utils/formatters' import { formatTokenPricePerMillion } from '@/utils/usagePricing' import { getUsageServiceTierLabel } from '@/utils/usageServiceTier' import { resolveUsageRequestType } from '@/utils/usageRequestType' +import { getBillingModeLabel, getBillingModeBadgeClass } from '@/utils/billingMode' import DataTable from '@/components/common/DataTable.vue' import EmptyState from '@/components/common/EmptyState.vue' import Icon from '@/components/icons/Icon.vue' @@ -399,17 +400,6 @@ const getRequestTypeBadgeClass = (row: AdminUsageLog): string => { return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200' } -const getBillingModeLabel = (mode: string | null | undefined): string => { - if (mode === 'per_request') return t('admin.usage.billingModePerRequest') - if (mode === 'image') return t('admin.usage.billingModeImage') - return t('admin.usage.billingModeToken') -} - -const getBillingModeBadgeClass = (mode: string | null | undefined): string => { - if (mode === 'per_request') return 'bg-blue-100 text-blue-800 dark:bg-blue-900 dark:text-blue-200' - if (mode === 'image') return 'bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200' - return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-200' -} const formatUserAgent = (ua: string): string => { diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 056bdde0..1f9ba3e1 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1889,7 +1889,8 @@ export default { searchAccountPlaceholder: 'Search accounts...', ruleAccountsHint: 'Leave empty to match all accounts', ruleModelPricing: 'Model Pricing', - noGroupsInChannel: 'No groups selected in platform tabs above' + noGroupsInChannel: 'No groups selected in platform tabs above', + unnamed: 'Unnamed' } }, diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 47a1f8d4..fa5d970c 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1968,7 +1968,8 @@ export default { searchAccountPlaceholder: '搜索账号...', ruleAccountsHint: '留空表示匹配所有账号', ruleModelPricing: '模型定价', - noGroupsInChannel: '上方平台标签页中未选择分组' + noGroupsInChannel: '上方平台标签页中未选择分组', + unnamed: '未命名' } }, diff --git a/frontend/src/utils/billingMode.ts b/frontend/src/utils/billingMode.ts new file mode 100644 index 00000000..152dadc4 --- /dev/null +++ b/frontend/src/utils/billingMode.ts @@ -0,0 +1,19 @@ +export const BILLING_MODE_TOKEN = 'token' +export const BILLING_MODE_PER_REQUEST = 'per_request' +export const BILLING_MODE_IMAGE = 'image' + +export function getBillingModeLabel(mode: string | null | undefined, t: (key: string) => string): string { + switch (mode) { + case BILLING_MODE_PER_REQUEST: return t('admin.usage.billingModePerRequest') + case BILLING_MODE_IMAGE: return t('admin.usage.billingModeImage') + default: return t('admin.usage.billingModeToken') + } +} + +export function getBillingModeBadgeClass(mode: string | null | undefined): string { + switch (mode) { + case BILLING_MODE_PER_REQUEST: return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' + case BILLING_MODE_IMAGE: return 'bg-pink-100 text-pink-700 dark:bg-pink-900/30 dark:text-pink-300' + default: return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' + } +} diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 5be45cb5..b714ca30 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -980,26 +980,38 @@ function clearAllRuleAccountSearchState() { showRuleAccountDropdown.value = {} } +function inferRulePlatform(groupIds: number[]): string { + const platforms = new Set() + for (const gid of groupIds) { + const group = allGroups.value.find(g => g.id === gid) + if (group) platforms.add(group.platform) + } + return platforms.size === 1 ? [...platforms][0] : '' +} + function accountStatsRulesToAPI(): AccountStatsPricingRule[] { - return form.account_stats_pricing_rules.map(rule => ({ - name: rule.name, - group_ids: rule.group_ids, - account_ids: rule.account_ids, - pricing: rule.pricing - .filter(p => p.models.length > 0) - .map(p => ({ - platform: '', - models: p.models, - billing_mode: p.billing_mode, - input_price: mTokToPerToken(p.input_price), - output_price: mTokToPerToken(p.output_price), - cache_write_price: mTokToPerToken(p.cache_write_price), - cache_read_price: mTokToPerToken(p.cache_read_price), - image_output_price: mTokToPerToken(p.image_output_price), - per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null, - intervals: formIntervalsToAPI(p.intervals || []) - })) - })) + return form.account_stats_pricing_rules.map(rule => { + const platform = inferRulePlatform(rule.group_ids) + return { + name: rule.name, + group_ids: rule.group_ids, + account_ids: rule.account_ids, + pricing: rule.pricing + .filter(p => p.models.length > 0) + .map(p => ({ + platform, + models: p.models, + billing_mode: p.billing_mode, + input_price: mTokToPerToken(p.input_price), + output_price: mTokToPerToken(p.output_price), + cache_write_price: mTokToPerToken(p.cache_write_price), + cache_read_price: mTokToPerToken(p.cache_read_price), + image_output_price: mTokToPerToken(p.image_output_price), + per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null, + intervals: formIntervalsToAPI(p.intervals || []) + })) + } + }) } // ── Form ↔ API conversion ── @@ -1329,7 +1341,7 @@ async function handleSubmit() { const intervalErr = validateIntervals(entry.intervals) if (intervalErr) { const platformLabel = t('admin.groups.platforms.' + section.platform, section.platform) - const modelLabel = entry.models.join(', ') || '未命名' + const modelLabel = entry.models.join(', ') || t('admin.channels.form.unnamed') appStore.showError(`${platformLabel} - ${modelLabel}: ${intervalErr}`) activeTab.value = section.platform return diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index c687f0db..d2e17440 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -2804,7 +2804,7 @@ import type { WebSearchProviderConfig, WebSearchTestResult, } from '@/api/admin/settings' -import type { AdminGroup, Proxy } from '@/types' +import type { AdminGroup, Proxy, NotifyEmailEntry } from '@/types' import type { ProviderInstance } from '@/types/payment' import AppLayout from '@/components/layout/AppLayout.vue' import Icon from '@/components/icons/Icon.vue' @@ -3028,7 +3028,7 @@ const form = reactive({ balance_low_notify_enabled: false, balance_low_notify_threshold: 0, account_quota_notify_enabled: false, - account_quota_notify_emails: [] as { email: string; disabled: boolean; verified: boolean }[] + account_quota_notify_emails: [] as NotifyEmailEntry[] }) // Proxies for web search emulation ProxySelector diff --git a/frontend/src/views/user/UsageView.vue b/frontend/src/views/user/UsageView.vue index 2ec0fea5..08298d89 100644 --- a/frontend/src/views/user/UsageView.vue +++ b/frontend/src/views/user/UsageView.vue @@ -192,7 +192,7 @@ @@ -524,6 +524,7 @@ import { formatCacheTokens, formatMultiplier } from '@/utils/formatters' import { formatTokenPricePerMillion } from '@/utils/usagePricing' import { getUsageServiceTierLabel } from '@/utils/usageServiceTier' import { resolveUsageRequestType } from '@/utils/usageRequestType' +import { getBillingModeLabel, getBillingModeBadgeClass } from '@/utils/billingMode' const { t } = useI18n() const appStore = useAppStore() @@ -644,17 +645,6 @@ const getRequestTypeBadgeClass = (log: UsageLog): string => { return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200' } -const getBillingModeLabel = (mode: string | null | undefined): string => { - if (mode === 'per_request') return t('admin.usage.billingModePerRequest') - if (mode === 'image') return t('admin.usage.billingModeImage') - return t('admin.usage.billingModeToken') -} - -const getBillingModeBadgeClass = (mode: string | null | undefined): string => { - if (mode === 'per_request') return 'bg-blue-100 text-blue-800 dark:bg-blue-900/30 dark:text-blue-200' - if (mode === 'image') return 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-200' - return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-300' -} const getRequestTypeExportText = (log: UsageLog): string => { const requestType = resolveUsageRequestType(log) @@ -866,7 +856,7 @@ const exportToCSV = async () => { formatReasoningEffort(log.reasoning_effort), log.inbound_endpoint || '', getRequestTypeExportText(log), - getBillingModeLabel(log.billing_mode), + getBillingModeLabel(log.billing_mode, t), log.input_tokens, log.output_tokens, log.cache_read_tokens, From b7fb2e43871887e37d3427376b832091b8cda8d2 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 13:59:35 +0800 Subject: [PATCH 43/88] fix: audit fixes for websearch, notifications, and channel pricing P0: fix wildcard matching test assertion (config order, not longest prefix) P0: add TotalRecharged to auth cache snapshot (v5) for percentage threshold P1: move pricing rules into per-platform sections in ChannelsView P1: populate account name cache when editing existing channel rules P1: sanitize email subject headers to prevent SMTP injection P1: make Redis INCR+EXPIRE idempotent for rate limiting P1: deep copy FeaturesConfig in Channel.Clone() P2: clean up stale email="" placeholder comments P2: replace log.Printf with slog in email_service.go --- backend/cmd/server/VERSION | 2 +- .../handler/dto/notify_email_entry.go | 2 +- backend/internal/handler/user_handler.go | 2 +- backend/internal/repository/email_cache.go | 38 +++- .../service/account_stats_pricing_test.go | 4 +- .../internal/service/api_key_auth_cache.go | 1 + .../service/api_key_auth_cache_impl.go | 4 +- .../service/balance_notify_service.go | 9 +- backend/internal/service/channel.go | 16 ++ backend/internal/service/email_service.go | 12 +- .../internal/service/notify_email_entry.go | 2 +- backend/internal/service/user_service.go | 103 ++++++--- frontend/src/views/admin/ChannelsView.vue | 196 ++++++++++++------ 13 files changed, 273 insertions(+), 118 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 061bed0b..b2581b19 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.11 +0.1.110.20 diff --git a/backend/internal/handler/dto/notify_email_entry.go b/backend/internal/handler/dto/notify_email_entry.go index 180f8b25..78641005 100644 --- a/backend/internal/handler/dto/notify_email_entry.go +++ b/backend/internal/handler/dto/notify_email_entry.go @@ -3,7 +3,7 @@ package dto import "github.com/Wei-Shaw/sub2api/internal/service" // NotifyEmailEntry represents a notification email with enable/disable and verification state. -// Email="" is a placeholder for the "primary email" (user's registration email or first admin email). +// All emails are user-managed; maximum 3 entries per user. type NotifyEmailEntry struct { Email string `json:"email"` Disabled bool `json:"disabled"` diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 9e0a243a..2535ea5e 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -217,7 +217,7 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) { // ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state type ToggleNotifyEmailRequest struct { - Email string `json:"email"` // empty string for primary email placeholder + Email string `json:"email" binding:"required,email"` Disabled bool `json:"disabled"` } diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go index 63552ab0..ed903e0d 100644 --- a/backend/internal/repository/email_cache.go +++ b/backend/internal/repository/email_cache.go @@ -3,6 +3,7 @@ package repository import ( "context" "encoding/json" + "fmt" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -10,10 +11,11 @@ import ( ) const ( - verifyCodeKeyPrefix = "verify_code:" - notifyVerifyKeyPrefix = "notify_verify:" - passwordResetKeyPrefix = "password_reset:" - passwordResetSentAtKeyPrefix = "password_reset_sent:" + verifyCodeKeyPrefix = "verify_code:" + notifyVerifyKeyPrefix = "notify_verify:" + passwordResetKeyPrefix = "password_reset:" + passwordResetSentAtKeyPrefix = "password_reset_sent:" + notifyCodeUserRateKeyPrefix = "notify_code_user_rate:" ) // verifyCodeKey generates the Redis key for email verification code. @@ -141,3 +143,31 @@ func (c *emailCache) DeleteNotifyVerifyCode(ctx context.Context, email string) e key := notifyVerifyKey(email) return c.rdb.Del(ctx, key).Err() } + +// User-level rate limiting for notify email verification codes + +func notifyCodeUserRateKey(userID int64) string { + return notifyCodeUserRateKeyPrefix + fmt.Sprintf("%d", userID) +} + +func (c *emailCache) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) { + key := notifyCodeUserRateKey(userID) + count, err := c.rdb.Incr(ctx, key).Result() + if err != nil { + return 0, err + } + // Always set TTL (idempotent) to avoid orphan keys if process crashes between INCR and EXPIRE. + if err := c.rdb.Expire(ctx, key, window).Err(); err != nil { + return count, fmt.Errorf("expire notify code rate key: %w", err) + } + return count, nil +} + +func (c *emailCache) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) { + key := notifyCodeUserRateKey(userID) + count, err := c.rdb.Get(ctx, key).Int64() + if err != nil { + return 0, err + } + return count, nil +} diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go index bf9da978..23409d5e 100644 --- a/backend/internal/service/account_stats_pricing_test.go +++ b/backend/internal/service/account_stats_pricing_test.go @@ -145,14 +145,14 @@ func TestFindPricingForModel(t *testing.T) { wantNil: true, }, { - name: "longer wildcard prefix wins over shorter", + name: "wildcard matches by config order (first match wins)", list: []ChannelModelPricing{ {ID: 10, Models: []string{"claude-*"}}, {ID: 11, Models: []string{"claude-opus-*"}}, }, platform: "", model: "claude-opus-4", - wantID: 11, // "claude-opus-" (12 chars) > "claude-" (7 chars) + wantID: 10, // config order: "claude-*" is first and matches, so it wins }, { name: "shorter wildcard used when longer does not match", diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 60cb6233..b1660ea7 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -42,6 +42,7 @@ type APIKeyAuthUserSnapshot struct { BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"` BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"` + TotalRecharged float64 `json:"total_recharged"` } // APIKeyAuthGroupSnapshot 分组快照 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 711090c2..25c6331a 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -13,7 +13,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 4 // v4: added balance notification fields to UserSnapshot +const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold type apiKeyAuthCacheConfig struct { l1Size int @@ -230,6 +230,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType, BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold, BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails, + TotalRecharged: apiKey.User.TotalRecharged, }, } if apiKey.Group != nil { @@ -291,6 +292,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType, BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold, BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails, + TotalRecharged: snapshot.User.TotalRecharged, }, } if snapshot.Group != nil { diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 1e4d8ff6..14aa6766 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -309,7 +309,7 @@ func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userNam if displayName == "" { displayName = userEmail } - subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", siteName) + subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName)) body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName)) s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance) } @@ -321,11 +321,16 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun dimLabel = dimension } - subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", siteName, accountName) + subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName)) body := s.buildQuotaAlertEmailBody(html.EscapeString(accountName), html.EscapeString(dimLabel), used, limit, threshold, html.EscapeString(siteName)) s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension) } +// sanitizeEmailHeader removes CR/LF characters to prevent SMTP header injection. +func sanitizeEmailHeader(s string) string { + return strings.NewReplacer("\r", "", "\n", "").Replace(s) +} + // balanceLowEmailTemplate is the HTML template for balance low notifications. // Format args: siteName, userName, userName, balance, threshold, threshold. const balanceLowEmailTemplate = ` diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index 3867f2a0..b034fda0 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -196,6 +196,9 @@ func (c *Channel) Clone() *Channel { cp.ModelMapping[platform] = inner } } + if c.FeaturesConfig != nil { + cp.FeaturesConfig = deepCopyFeaturesConfig(c.FeaturesConfig) + } if c.AccountStatsPricingRules != nil { cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules)) for i, rule := range c.AccountStatsPricingRules { @@ -219,6 +222,19 @@ func (c *Channel) Clone() *Channel { return &cp } +// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution. +func deepCopyFeaturesConfig(src map[string]any) map[string]any { + dst := make(map[string]any, len(src)) + for k, v := range src { + if inner, ok := v.(map[string]any); ok { + dst[k] = deepCopyFeaturesConfig(inner) + } else { + dst[k] = v + } + } + return dst +} + // ValidateIntervals 校验区间列表的合法性。 // 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens; // 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义); diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 61090776..3eade83e 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -7,7 +7,7 @@ import ( "crypto/tls" "encoding/hex" "fmt" - "log" + "log/slog" "math/big" "net/smtp" "net/url" @@ -292,7 +292,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 { data.Attempts++ if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { - log.Printf("[Email] Failed to update verification attempt count: %v", err) + slog.Error("failed to update verification attempt count", "email", email, "error", err) } if data.Attempts >= maxVerifyCodeAttempts { return ErrVerifyCodeMaxAttempts @@ -302,7 +302,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error // 验证成功,删除验证码 if err := s.cache.DeleteVerificationCode(ctx, email); err != nil { - log.Printf("[Email] Failed to delete verification code after success: %v", err) + slog.Error("failed to delete verification code after success", "email", email, "error", err) } return nil } @@ -452,7 +452,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error { // Check email cooldown to prevent email bombing if s.cache.IsPasswordResetEmailInCooldown(ctx, email) { - log.Printf("[Email] Password reset email skipped (cooldown): %s", email) + slog.Info("password reset email skipped due to cooldown", "email", email) return nil // Silent success to prevent revealing cooldown to attackers } @@ -463,7 +463,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e // Set cooldown marker (Redis TTL handles expiration) if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil { - log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err) + slog.Error("failed to set password reset cooldown", "email", email, "error", err) } return nil @@ -493,7 +493,7 @@ func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, tok // Delete after verification (one-time use) if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil { - log.Printf("[Email] Failed to delete password reset token after consumption: %v", err) + slog.Error("failed to delete password reset token after consumption", "email", email, "error", err) } return nil } diff --git a/backend/internal/service/notify_email_entry.go b/backend/internal/service/notify_email_entry.go index c0e739f4..d181200b 100644 --- a/backend/internal/service/notify_email_entry.go +++ b/backend/internal/service/notify_email_entry.go @@ -6,7 +6,7 @@ import ( ) // NotifyEmailEntry represents a notification email with enable/disable and verification state. -// Email="" is a placeholder for the "primary email" (user's registration email or first admin email). +// All emails are user-managed; maximum 3 entries per user. type NotifyEmailEntry struct { Email string `json:"email"` Disabled bool `json:"disabled"` diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index bcb21c1d..3baee81d 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -4,7 +4,7 @@ import ( "context" "crypto/subtle" "fmt" - "log" + "log/slog" "strings" "time" @@ -13,12 +13,19 @@ import ( ) var ( - ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") - ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") - ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") + ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") + ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") + ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") + ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later") ) -const maxNotifyEmails = 3 // Total limit: primary (email="") + up to 2 extra +const ( + maxNotifyEmails = 3 // Maximum number of notification emails per user + + // User-level rate limiting for notify email verification codes + notifyCodeUserRateLimit = 5 + notifyCodeUserRateWindow = 10 * time.Minute +) // UserListFilters contains all filter options for listing users type UserListFilters struct { @@ -220,7 +227,7 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil { - log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) + slog.Error("invalidate user balance cache failed", "user_id", userID, "error", err) } }() } @@ -270,21 +277,44 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error { // SendNotifyEmailCode sends a verification code to the extra notification email. func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error { - // Check cooldown + if err := checkNotifyCodeRateLimit(ctx, cache, userID, email); err != nil { + return err + } + + code, err := emailService.GenerateVerifyCode() + if err != nil { + return fmt.Errorf("generate code: %w", err) + } + + if err := saveNotifyVerifyCode(ctx, cache, email, code); err != nil { + return err + } + + // Increment user-level counter after successful save + if _, err := cache.IncrNotifyCodeUserRate(ctx, userID, notifyCodeUserRateWindow); err != nil { + slog.Error("failed to increment notify code user rate", "user_id", userID, "error", err) + } + + return s.sendNotifyVerifyEmail(ctx, emailService, email, code) +} + +// checkNotifyCodeRateLimit checks both email cooldown and user-level rate limit. +func checkNotifyCodeRateLimit(ctx context.Context, cache EmailCache, userID int64, email string) error { existing, err := cache.GetNotifyVerifyCode(ctx, email) if err == nil && existing != nil { if time.Since(existing.CreatedAt) < verifyCodeCooldown { return ErrVerifyCodeTooFrequent } } - - // Generate code - code, err := emailService.GenerateVerifyCode() - if err != nil { - return fmt.Errorf("generate code: %w", err) + count, err := cache.GetNotifyCodeUserRate(ctx, userID) + if err == nil && count >= notifyCodeUserRateLimit { + return ErrNotifyCodeUserRateLimit } + return nil +} - // Save to cache +// saveNotifyVerifyCode saves the verification code to cache. +func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code string) error { data := &VerificationCodeData{ Code: code, Attempts: 0, @@ -293,16 +323,17 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil { return fmt.Errorf("save verify code: %w", err) } + return nil +} - // Get site name +// sendNotifyVerifyEmail builds and sends the verification email. +func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, email, code string) error { siteName := "Sub2API" if s.settingRepo != nil { if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" { siteName = name } } - - // Build and send email subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName) body := buildNotifyVerifyEmailBody(code, siteName) return emailService.SendEmail(ctx, email, subject, body) @@ -310,7 +341,15 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema // VerifyAndAddNotifyEmail verifies the code and adds the email to user's extra emails. func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, email, code string, cache EmailCache) error { - // Verify code + if err := verifyNotifyCode(ctx, cache, email, code); err != nil { + return err + } + _ = cache.DeleteNotifyVerifyCode(ctx, email) + return s.addOrVerifyNotifyEmail(ctx, userID, email) +} + +// verifyNotifyCode validates the verification code against the cached data. +func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string) error { data, err := cache.GetNotifyVerifyCode(ctx, email) if err != nil || data == nil { return ErrInvalidVerifyCode @@ -326,17 +365,18 @@ func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, } return ErrInvalidVerifyCode } + return nil +} - // Delete code after verification - _ = cache.DeleteNotifyVerifyCode(ctx, email) - - // Add to user's extra emails +// addOrVerifyNotifyEmail adds the email to user's extra notification emails or marks it as verified. +// Note: concurrent calls for the same user could race on the read-modify-write of +// BalanceNotifyExtraEmails. The window is small (requires two verify flows completing +// simultaneously), and the worst case is a duplicate entry which is harmless. +func (s *UserService) addOrVerifyNotifyEmail(ctx context.Context, userID int64, email string) error { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { return err } - - // Check if already exists — if unverified, mark as verified for i, e := range user.BalanceNotifyExtraEmails { if strings.EqualFold(e.Email, email) { if !e.Verified { @@ -346,12 +386,9 @@ func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, return nil // Already verified } } - - // Check limit if len(user.BalanceNotifyExtraEmails) >= maxNotifyEmails { return infraerrors.BadRequest("TOO_MANY_NOTIFY_EMAILS", fmt.Sprintf("maximum %d notification emails allowed", maxNotifyEmails)) } - user.BalanceNotifyExtraEmails = append(user.BalanceNotifyExtraEmails, NotifyEmailEntry{ Email: email, Disabled: false, @@ -399,10 +436,9 @@ func (s *UserService) ToggleNotifyEmail(ctx context.Context, userID int64, email return s.userRepo.Update(ctx, user) } -// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification. -func buildNotifyVerifyEmailBody(code, siteName string) string { - return fmt.Sprintf(` - +// notifyVerifyEmailTemplate is the HTML template for notify email verification. +// Format args: siteName, code. +const notifyVerifyEmailTemplate = ` @@ -439,6 +475,9 @@ func buildNotifyVerifyEmailBody(code, siteName string) string {

- -`, siteName, code) +` + +// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification. +func buildNotifyVerifyEmailBody(code, siteName string) string { + return fmt.Sprintf(notifyVerifyEmailTemplate, siteName, code) } diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index b714ca30..2ca1141d 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -421,7 +421,7 @@
@@ -524,7 +524,7 @@
-
@@ -538,7 +538,7 @@ :entry="entry" :platform="section.platform" @update="rule.pricing.splice(pIdx, 1, $event)" - @remove="removeRulePricingEntry(ruleIndex, pIdx)" + @remove="removeRulePricingEntry(sIdx, ruleIndex, pIdx)" />
@@ -625,6 +625,14 @@ async function loadWebSearchGlobalState() { } } +// ── Form-level pricing rule type (per-platform) ── +interface FormPricingRule { + name: string + group_ids: number[] + account_ids: number[] + pricing: PricingFormEntry[] +} + // ── Platform Section type ── interface PlatformSection { platform: GroupPlatform @@ -634,6 +642,7 @@ interface PlatformSection { model_mapping: Record model_pricing: PricingFormEntry[] web_search_emulation: boolean + account_stats_pricing_rules: FormPricingRule[] } // ── Table columns ── @@ -703,12 +712,6 @@ const form = reactive({ billing_model_source: 'channel_mapped' as string, platforms: [] as PlatformSection[], apply_pricing_to_account_stats: false, - account_stats_pricing_rules: [] as Array<{ - name: string - group_ids: number[] - account_ids: number[] - pricing: PricingFormEntry[] - }> }) let abortController: AbortController | null = null @@ -754,6 +757,7 @@ function addPlatformSection(platform: GroupPlatform) { model_mapping: {}, model_pricing: [], web_search_emulation: false, + account_stats_pricing_rules: [], }) } @@ -867,8 +871,8 @@ function renameMappingKey(sectionIdx: number, oldKey: string, newKey: string) { } // ── Account Stats Pricing helpers ── -function addAccountStatsRule() { - form.account_stats_pricing_rules.push({ +function addAccountStatsRule(sectionIdx: number) { + form.platforms[sectionIdx].account_stats_pricing_rules.push({ name: '', group_ids: [], account_ids: [], @@ -876,8 +880,8 @@ function addAccountStatsRule() { }) } -function addRulePricingEntry(ruleIndex: number) { - form.account_stats_pricing_rules[ruleIndex].pricing.push({ +function addRulePricingEntry(sectionIdx: number, ruleIndex: number) { + form.platforms[sectionIdx].account_stats_pricing_rules[ruleIndex].pricing.push({ models: [], billing_mode: 'token', input_price: null, @@ -890,15 +894,15 @@ function addRulePricingEntry(ruleIndex: number) { }) } -function removeAccountStatsRule(ruleIndex: number) { - form.account_stats_pricing_rules.splice(ruleIndex, 1) +function removeAccountStatsRule(sectionIdx: number, ruleIndex: number) { + form.platforms[sectionIdx].account_stats_pricing_rules.splice(ruleIndex, 1) // Clear all search state since indices shift after removal ruleAccountSearchRunner.clearAll() clearAllRuleAccountSearchState() } -function removeRulePricingEntry(ruleIndex: number, pricingIndex: number) { - form.account_stats_pricing_rules[ruleIndex].pricing.splice(pricingIndex, 1) +function removeRulePricingEntry(sectionIdx: number, ruleIndex: number, pricingIndex: number) { + form.platforms[sectionIdx].account_stats_pricing_rules[ruleIndex].pricing.splice(pricingIndex, 1) } function getGroupNameById(groupId: number): string { @@ -980,38 +984,33 @@ function clearAllRuleAccountSearchState() { showRuleAccountDropdown.value = {} } -function inferRulePlatform(groupIds: number[]): string { - const platforms = new Set() - for (const gid of groupIds) { - const group = allGroups.value.find(g => g.id === gid) - if (group) platforms.add(group.platform) - } - return platforms.size === 1 ? [...platforms][0] : '' -} - function accountStatsRulesToAPI(): AccountStatsPricingRule[] { - return form.account_stats_pricing_rules.map(rule => { - const platform = inferRulePlatform(rule.group_ids) - return { - name: rule.name, - group_ids: rule.group_ids, - account_ids: rule.account_ids, - pricing: rule.pricing - .filter(p => p.models.length > 0) - .map(p => ({ - platform, - models: p.models, - billing_mode: p.billing_mode, - input_price: mTokToPerToken(p.input_price), - output_price: mTokToPerToken(p.output_price), - cache_write_price: mTokToPerToken(p.cache_write_price), - cache_read_price: mTokToPerToken(p.cache_read_price), - image_output_price: mTokToPerToken(p.image_output_price), - per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null, - intervals: formIntervalsToAPI(p.intervals || []) - })) + const rules: AccountStatsPricingRule[] = [] + for (const section of form.platforms) { + if (!section.enabled) continue + for (const rule of section.account_stats_pricing_rules) { + rules.push({ + name: rule.name, + group_ids: rule.group_ids, + account_ids: rule.account_ids, + pricing: rule.pricing + .filter(p => p.models.length > 0) + .map(p => ({ + platform: section.platform, + models: p.models, + billing_mode: p.billing_mode, + input_price: mTokToPerToken(p.input_price), + output_price: mTokToPerToken(p.output_price), + cache_write_price: mTokToPerToken(p.cache_write_price), + cache_read_price: mTokToPerToken(p.cache_read_price), + image_output_price: mTokToPerToken(p.image_output_price), + per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null, + intervals: formIntervalsToAPI(p.intervals || []) + })) + }) } - }) + } + return rules } // ── Form ↔ API conversion ── @@ -1120,6 +1119,7 @@ function apiToForm(channel: Channel): PlatformSection[] { model_mapping: { ...mapping }, model_pricing: pricing, web_search_emulation: webSearchEnabled, + account_stats_pricing_rules: [], }) } @@ -1213,7 +1213,6 @@ function resetForm() { form.billing_model_source = 'channel_mapped' form.platforms = [] form.apply_pricing_to_account_stats = false - form.account_stats_pricing_rules = [] activeTab.value = 'basic' ruleAccountSearchRunner.clearAll() clearAllRuleAccountSearchState() @@ -1235,28 +1234,91 @@ async function openEditDialog(channel: Channel) { form.restrict_models = channel.restrict_models || false form.billing_model_source = channel.billing_model_source || 'channel_mapped' form.apply_pricing_to_account_stats = channel.apply_pricing_to_account_stats || false - form.account_stats_pricing_rules = (channel.account_stats_pricing_rules || []).map(rule => ({ - name: rule.name || '', - group_ids: [...(rule.group_ids || [])], - account_ids: [...(rule.account_ids || [])], - pricing: (rule.pricing || []).map(p => ({ - models: [...(p.models || [])], - billing_mode: p.billing_mode, - input_price: perTokenToMTok(p.input_price), - output_price: perTokenToMTok(p.output_price), - cache_write_price: perTokenToMTok(p.cache_write_price), - cache_read_price: perTokenToMTok(p.cache_read_price), - image_output_price: perTokenToMTok(p.image_output_price), - per_request_price: p.per_request_price, - intervals: apiIntervalsToForm(p.intervals || []) - } as PricingFormEntry)) - })) // Must load groups first so apiToForm can map groupID → platform await Promise.all([loadGroups(), loadAllChannelsForConflict()]) form.platforms = apiToForm(channel) + + // Distribute channel-level rules into per-platform sections + distributeRulesToPlatforms(channel.account_stats_pricing_rules || []) + + // Populate ruleAccountNameCache for existing rule accounts + await populateRuleAccountNameCache() + showDialog.value = true } +/** Distribute flat channel-level rules into the matching platform section based on group_ids */ +function distributeRulesToPlatforms(apiRules: AccountStatsPricingRule[]) { + // Build groupID → platform lookup + const groupPlatformMap = new Map() + for (const g of allGroups.value) { + groupPlatformMap.set(g.id, g.platform) + } + + for (const apiRule of apiRules) { + // Infer platform from group_ids + const platforms = new Set() + for (const gid of apiRule.group_ids || []) { + const p = groupPlatformMap.get(gid) + if (p) platforms.add(p) + } + // If pricing has a platform field, use that as fallback + if (platforms.size === 0 && apiRule.pricing?.length > 0) { + const p = apiRule.pricing[0].platform as GroupPlatform | undefined + if (p) platforms.add(p) + } + const targetPlatform = platforms.size >= 1 ? [...platforms][0] : null + if (!targetPlatform) continue + + const section = form.platforms.find(s => s.platform === targetPlatform) + if (!section) continue + + const formRule: FormPricingRule = { + name: apiRule.name || '', + group_ids: [...(apiRule.group_ids || [])], + account_ids: [...(apiRule.account_ids || [])], + pricing: (apiRule.pricing || []).map(p => ({ + models: [...(p.models || [])], + billing_mode: p.billing_mode, + input_price: perTokenToMTok(p.input_price), + output_price: perTokenToMTok(p.output_price), + cache_write_price: perTokenToMTok(p.cache_write_price), + cache_read_price: perTokenToMTok(p.cache_read_price), + image_output_price: perTokenToMTok(p.image_output_price), + per_request_price: p.per_request_price, + intervals: apiIntervalsToForm(p.intervals || []) + } as PricingFormEntry)) + } + section.account_stats_pricing_rules.push(formRule) + } +} + +/** Populate ruleAccountNameCache by fetching account details for all account_ids in rules */ +async function populateRuleAccountNameCache() { + const allAccountIds = new Set() + for (const section of form.platforms) { + for (const rule of section.account_stats_pricing_rules) { + for (const id of rule.account_ids) { + allAccountIds.add(id) + } + } + } + if (allAccountIds.size === 0) return + + // Fetch account details in parallel (batch of individual getById calls) + const ids = [...allAccountIds] + const results = await Promise.allSettled( + ids.map(id => adminAPI.accounts.getById(id)) + ) + for (let i = 0; i < ids.length; i++) { + const result = results[i] + if (result.status === 'fulfilled') { + ruleAccountNameCache.value[ids[i]] = result.value.name + } + // If rejected, the cache won't have the name, so it'll show "#ID" which is acceptable + } +} + function closeDialog() { showDialog.value = false editingChannel.value = null From b1875f0b826b48d941df751061f4c92071498ed9 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 14:21:37 +0800 Subject: [PATCH 44/88] fix: round 3 audit fixes - SMTP header sanitization and goroutine safety - Move sanitizeEmailHeader to SendEmailWithConfig entry point, covering all email senders (verify code, password reset, ops alerts, notifications) - Add panic recovery to UpdateBalance goroutine - Fix stale comment in getAccountQuotaNotifyEmails (email="" no longer used) - Log error instead of silently discarding verifyNotifyCode cache update failure --- backend/cmd/server/VERSION | 2 +- backend/internal/service/balance_notify_service.go | 2 +- backend/internal/service/email_service.go | 8 ++++++-- backend/internal/service/user_service.go | 9 ++++++++- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index b2581b19..49e7ca57 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.20 +0.1.110.21 diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 14aa6766..a392a13e 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -225,7 +225,7 @@ func (s *BalanceNotifyService) isAccountQuotaNotifyEnabled(ctx context.Context) } // getAccountQuotaNotifyEmails reads admin notification emails from settings, -// filtering out disabled entries. Entries with email="" are resolved to the first admin's email. +// filtering out disabled and unverified entries. func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) []string { raw, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEmails) if err != nil || strings.TrimSpace(raw) == "" || raw == "[]" { diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 3eade83e..50d324f2 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -153,9 +153,13 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) // SendEmailWithConfig 使用指定配置发送邮件 func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error { - from := config.From + // Sanitize all SMTP header fields to prevent header injection (CR/LF removal). + to = sanitizeEmailHeader(to) + subject = sanitizeEmailHeader(subject) + + from := sanitizeEmailHeader(config.From) if config.FromName != "" { - from = fmt.Sprintf("%s <%s>", config.FromName, config.From) + from = fmt.Sprintf("%s <%s>", sanitizeEmailHeader(config.FromName), sanitizeEmailHeader(config.From)) } msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: %s\r\nMIME-Version: 1.0\r\nContent-Type: text/html; charset=UTF-8\r\n\r\n%s", diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 3baee81d..0da73762 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -224,6 +224,11 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl } if s.billingCache != nil { go func() { + defer func() { + if r := recover(); r != nil { + slog.Error("panic in balance cache invalidation", "user_id", userID, "recover", r) + } + }() cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil { @@ -359,7 +364,9 @@ func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string) } if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 { data.Attempts++ - _ = cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL) + if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil { + slog.Error("failed to update notify verify code attempts", "email", email, "error", err) + } if data.Attempts >= maxVerifyCodeAttempts { return ErrVerifyCodeMaxAttempts } From 48e8efe3e8489a34dea34528caefcb4a64fec4cc Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 15:13:56 +0800 Subject: [PATCH 45/88] fix(frontend): hide quota notify toggle when global setting is disabled QuotaLimitCard now requires quotaNotifyGlobalEnabled prop to control visibility of QuotaNotifyToggle components. When the global account quota notification is disabled in admin settings, per-account threshold toggles are hidden in both Edit and Create account modals. --- backend/cmd/server/VERSION | 2 +- .../components/account/CreateAccountModal.vue | 53 +++++++++++++++++-- .../components/account/EditAccountModal.vue | 9 +++- .../src/components/account/QuotaLimitCard.vue | 8 +-- 4 files changed, 62 insertions(+), 10 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 49e7ca57..d6418142 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.21 +0.1.110.23 diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index e83e061e..e3d2d19a 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1477,10 +1477,47 @@ - -
+ +
-

{{ t('admin.accounts.quotaLimit') }}

+

{{ t('admin.accounts.quotaControl.title') }}

+

+ {{ t('admin.accounts.quotaControl.hint') }} +

+
+ +
+ + +
+
+

{{ t('admin.accounts.quotaControl.title') }}

{{ t('admin.accounts.quotaLimitHint') }}

@@ -1489,6 +1526,7 @@ :totalLimit="editQuotaLimit" :dailyLimit="editQuotaDailyLimit" :weeklyLimit="editQuotaWeeklyLimit" + :quotaNotifyGlobalEnabled="quotaNotifyGlobalEnabled" :dailyResetMode="editDailyResetMode" :dailyResetHour="editDailyResetHour" :weeklyResetMode="editWeeklyResetMode" @@ -1823,7 +1861,7 @@
- +
{ webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0 }).catch(() => { webSearchGlobalEnabled.value = false }) + +adminAPI.settings.getSettings().then(settings => { + quotaNotifyGlobalEnabled.value = settings.account_quota_notify_enabled === true +}).catch(() => { quotaNotifyGlobalEnabled.value = false }) const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 74e5fc1f..89f1fb18 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1190,6 +1190,7 @@ :weeklyResetDay="editWeeklyResetDay" :weeklyResetHour="editWeeklyResetHour" :resetTimezone="editResetTimezone" + :quotaNotifyGlobalEnabled="quotaNotifyGlobalEnabled" :quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled" :quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold" :quotaNotifyDailyThresholdType="editQuotaNotifyDailyThresholdType" @@ -1240,6 +1241,7 @@ :weeklyResetDay="editWeeklyResetDay" :weeklyResetHour="editWeeklyResetHour" :resetTimezone="editResetTimezone" + :quotaNotifyGlobalEnabled="quotaNotifyGlobalEnabled" :quotaNotifyDailyEnabled="editQuotaNotifyDailyEnabled" :quotaNotifyDailyThreshold="editQuotaNotifyDailyThreshold" :quotaNotifyDailyThresholdType="editQuotaNotifyDailyThresholdType" @@ -1991,11 +1993,16 @@ const codexCLIOnlyEnabled = ref(false) const anthropicPassthroughEnabled = ref(false) const webSearchEmulationMode = ref('default') const webSearchGlobalEnabled = ref(false) +const quotaNotifyGlobalEnabled = ref(false) -// Load web search global state once +// Load global feature states once adminAPI.settings.getWebSearchEmulationConfig().then(cfg => { webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0 }).catch(() => { webSearchGlobalEnabled.value = false }) + +adminAPI.settings.getSettings().then(settings => { + quotaNotifyGlobalEnabled.value = settings.account_quota_notify_enabled === true +}).catch(() => { quotaNotifyGlobalEnabled.value = false }) const editQuotaLimit = ref(null) const editQuotaDailyLimit = ref(null) const editQuotaWeeklyLimit = ref(null) diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index 64bdb08a..38ab6479 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -15,6 +15,7 @@ const props = withDefaults(defineProps<{ weeklyResetDay: number | null weeklyResetHour: number | null resetTimezone: string | null + quotaNotifyGlobalEnabled?: boolean quotaNotifyDailyEnabled?: boolean | null quotaNotifyDailyThreshold?: number | null quotaNotifyDailyThresholdType?: string | null @@ -25,6 +26,7 @@ const props = withDefaults(defineProps<{ quotaNotifyTotalThreshold?: number | null quotaNotifyTotalThresholdType?: string | null }>(), { + quotaNotifyGlobalEnabled: false, quotaNotifyDailyEnabled: null, quotaNotifyDailyThreshold: null, quotaNotifyDailyThresholdType: null, @@ -234,7 +236,7 @@ const onWeeklyModeChange = (e: Event) => {

{

{

{{ t('admin.accounts.quotaTotalLimitHint') }}

Date: Mon, 13 Apr 2026 15:20:00 +0800 Subject: [PATCH 46/88] fix(frontend): simplify websearch select labels and reduce width MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - "默认(跟随渠道)" → "默认", "Default (follow channel)" → "Default" - Move "follows channel config" info to description text - Reduce select width from w-32 to w-24 in both Edit and Create modals --- frontend/src/components/account/CreateAccountModal.vue | 2 +- frontend/src/components/account/EditAccountModal.vue | 2 +- frontend/src/i18n/locales/en.ts | 4 ++-- frontend/src/i18n/locales/zh.ts | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index e3d2d19a..a1496fa8 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2375,7 +2375,7 @@ {{ t('admin.accounts.anthropic.webSearchEmulationDesc') }}

- diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 89f1fb18..613738d2 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1161,7 +1161,7 @@ {{ t('admin.accounts.anthropic.webSearchEmulationDesc') }}

- diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 1f9ba3e1..5c593946 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -2383,8 +2383,8 @@ export default { 'Only applies to Anthropic API Key accounts. When enabled, messages/count_tokens are forwarded in passthrough mode with auth replacement only, while billing/concurrency/audit and safety filtering are preserved. Disable to roll back immediately.', webSearchEmulation: 'Web Search Emulation', webSearchEmulationDesc: - 'Enable web search emulation for this API Key account. When a pure web_search request is detected, the gateway calls a third-party search API and constructs the response locally.', - webSearchDefault: 'Default (follow channel)', + 'Enable web search emulation for this API Key account. When a pure web_search request is detected, the gateway calls a third-party search API and constructs the response locally. Default follows channel config.', + webSearchDefault: 'Default', webSearchEnabled: 'Enabled', webSearchDisabled: 'Disabled', }, diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index fa5d970c..141193c1 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2530,8 +2530,8 @@ export default { '仅对 Anthropic API Key 生效。开启后,messages/count_tokens 请求将透传上游并仅替换认证,保留计费/并发/审计及必要安全过滤;关闭即可回滚到现有兼容链路。', webSearchEmulation: 'Web Search 模拟', webSearchEmulationDesc: - '为该 API Key 账号启用 web search 模拟。客户端发送纯 web_search 请求时,由网关调用第三方搜索 API 并构造响应返回。', - webSearchDefault: '默认(跟随渠道)', + '为该 API Key 账号启用 web search 模拟。客户端发送纯 web_search 请求时,由网关调用第三方搜索 API 并构造响应返回。默认跟随渠道配置。', + webSearchDefault: '默认', webSearchEnabled: '开启', webSearchDisabled: '关闭', }, From 42f8ef331576b43958cbf39def767a55857f3b32 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 15:30:06 +0800 Subject: [PATCH 47/88] fix: add missing AccountQuotaNotifyEnabled to admin settings API The field was present in SystemSettings response DTO and service layer but missing from: - UpdateSettingsRequest (admin handler) - saves were silently ignored - GET/PUT response mapping in admin handler - UpdateSettingsRequest (non-admin dto) This caused the toggle to always revert to off after saving. --- backend/cmd/server/VERSION | 2 +- backend/internal/handler/admin/setting_handler.go | 9 +++++++++ backend/internal/handler/dto/settings.go | 1 + 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index d6418142..8775ce89 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.23 +0.1.110.24 diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index fe46e821..58081273 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -178,6 +178,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, + AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled, AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails), PaymentEnabled: paymentCfg.Enabled, PaymentMinAmount: paymentCfg.MinAmount, @@ -311,6 +312,7 @@ type UpdateSettingsRequest struct { // Balance low notification BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` + AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"` AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"` // Payment configuration (integrated into settings, full replace) @@ -902,6 +904,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.BalanceLowNotifyThreshold }(), + AccountQuotaNotifyEnabled: func() bool { + if req.AccountQuotaNotifyEnabled != nil { + return *req.AccountQuotaNotifyEnabled + } + return previousSettings.AccountQuotaNotifyEnabled + }(), AccountQuotaNotifyEmails: func() []service.NotifyEmailEntry { if req.AccountQuotaNotifyEmails != nil { return dto.NotifyEmailEntriesToService(*req.AccountQuotaNotifyEmails) @@ -1056,6 +1064,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableCCHSigning: updatedSettings.EnableCCHSigning, BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold, + AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled, AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails), PaymentEnabled: updatedPaymentCfg.Enabled, PaymentMinAmount: updatedPaymentCfg.MinAmount, diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 545458c8..6c99208f 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -152,6 +152,7 @@ type SystemSettings struct { // Balance low notification BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` + AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"` } From 98c9d51791fe5905247268d183abc9d069ff98d3 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 16:45:10 +0800 Subject: [PATCH 48/88] fix: correct account stats pricing priority order MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Priority was wrong: - Before: custom rules → LiteLLM (when ApplyPricingToAccountStats) → nil - After: custom rules → totalCost (when ApplyPricingToAccountStats) → LiteLLM → nil When ApplyPricingToAccountStats is enabled, use the request's actual client billing cost (before multiplier) as account_stats_cost, instead of recalculating from LiteLLM per-token prices which produced incorrect values for per-request billing mode. LiteLLM model pricing is now the final fallback (priority 3), used only when neither custom rules nor ApplyPricingToAccountStats apply. --- backend/cmd/server/VERSION | 2 +- .../internal/service/account_stats_pricing.go | 20 +++++++++++++++---- backend/internal/service/gateway_service.go | 1 + .../service/openai_gateway_service.go | 2 +- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 8775ce89..c74c5ae7 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.24 +0.1.110.27 diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index cbe9c76c..8251dede 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -11,10 +11,12 @@ import ( // // 优先级(先命中为准): // 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关) -// 2. ApplyPricingToAccountStats 启用时,用模型定价文件(LiteLLM)中上游模型的标准价格计算 -// 3. nil → 走默认公式 +// 2. ApplyPricingToAccountStats 启用时,直接使用本次请求的客户计费(倍率前的 totalCost) +// 3. 模型定价文件(LiteLLM)中上游模型的默认价格 +// 4. nil → 走默认公式(total_cost × account_rate_multiplier) // // upstreamModel 是最终发往上游的模型 ID。 +// totalCost 是本次请求的客户计费(倍率前),用于优先级 2。 func resolveAccountStatsCost( ctx context.Context, channelService *ChannelService, @@ -24,6 +26,7 @@ func resolveAccountStatsCost( upstreamModel string, tokens UsageTokens, requestCount int, + totalCost float64, ) *float64 { if channelService == nil || upstreamModel == "" { return nil @@ -40,8 +43,17 @@ func resolveAccountStatsCost( return cost } - // 优先级 2:模型定价文件(LiteLLM/fallback)中上游模型的标准价格 - if channel.ApplyPricingToAccountStats && billingService != nil { + // 优先级 2:渠道开启"应用模型定价到账号统计"时,直接使用客户计费(倍率前) + if channel.ApplyPricingToAccountStats { + cost := totalCost + if cost <= 0 { + return nil + } + return &cost + } + + // 优先级 3:模型定价文件(LiteLLM)默认价格 + if billingService != nil { return tryModelFilePricing(billingService, upstreamModel, tokens) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 5267156d..b67a06a7 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -7598,6 +7598,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage ImageOutputTokens: result.Usage.ImageOutputTokens, }, 1, // requestCount + cost.TotalCost, ) } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index e060b981..9a6fbb8f 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4582,7 +4582,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec usageLog.AccountStatsCost = resolveAccountStatsCost( ctx, s.channelService, s.billingService, account.ID, *apiKey.GroupID, statsModel, - tokens, 1, + tokens, 1, cost.TotalCost, ) } From 2066c478ab3d408a67a05f98ca08e79c9d480b0c Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 16:52:02 +0800 Subject: [PATCH 49/88] fix(frontend): quota notify UI improvements - QuotaNotifyToggle: add $ or % suffix to threshold input based on type - QuotaLimitCard: combine reset mode and notify toggle on same row to reduce vertical height for daily/weekly sections - Remove redundant ml-4 indentation from QuotaNotifyToggle --- backend/cmd/server/VERSION | 2 +- .../src/components/account/QuotaLimitCard.vue | 74 +++++++++---------- .../components/account/QuotaNotifyToggle.vue | 32 ++++---- 3 files changed, 56 insertions(+), 52 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index c74c5ae7..e52af706 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.27 +0.1.110.28 diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index 38ab6479..682f51aa 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -203,28 +203,28 @@ const onWeeklyModeChange = (e: Event) => { :placeholder="t('admin.accounts.quotaLimitPlaceholder')" /> - -
+ +
-
- -
- - + +

- { :placeholder="t('admin.accounts.quotaLimitPlaceholder')" />
- -
+ +
-
- -
- - - - + +

- From ac55443278327228c3f9f1f19045a55cfb9bdbce Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 17:04:23 +0800 Subject: [PATCH 50/88] fix(frontend): collapsible quota card and compact notify layout - QuotaLimitCard: add collapse/expand toggle (chevron icon + click header) - QuotaNotifyToggle: show $ or % suffix in threshold input - Reduce vertical spacing between reset mode hint and notify toggle --- backend/cmd/server/VERSION | 2 +- .../src/components/account/QuotaLimitCard.vue | 42 ++++++++++--------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index e52af706..a7975f29 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.28 +0.1.110.30 diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index 682f51aa..832fee96 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -66,15 +66,17 @@ const enabled = computed(() => ) const localEnabled = ref(enabled.value) +const collapsed = ref(false) // Sync when props change externally watch(enabled, (val) => { localEnabled.value = val }) -// When toggle is turned off, clear all values +// When toggle is turned off, clear all values and expand watch(localEnabled, (val) => { if (!val) { + collapsed.value = false emit('update:totalLimit', null) emit('update:dailyLimit', null) emit('update:weeklyLimit', null) @@ -162,13 +164,19 @@ const onWeeklyModeChange = (e: Event) => {

-
-
- $ - +
+ $ + +
+
@@ -304,15 +319,6 @@ const onWeeklyModeChange = (e: Event) => { {{ t('admin.accounts.quotaWeeklyLimitHint') }}

-
@@ -330,28 +336,31 @@ const onWeeklyModeChange = (e: Event) => {
-
- $ - +
+ $ + +
+

{{ t('admin.accounts.quotaTotalLimitHint') }}

-
From 216bda58da5f52cbffaa8deb79f9891184adc315 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 17:38:33 +0800 Subject: [PATCH 52/88] fix: change quota notify threshold semantics to "remaining quota" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Threshold now represents remaining quota instead of usage amount: - Fixed ($): threshold=400, limit=1000 → alert when remaining drops to $400 (i.e., usage reaches $600) - Percentage (%): threshold=30%, limit=1000 → alert when remaining drops to 30% (i.e., usage reaches $700) Also: - Rename 告警阈值 → 提醒阈值 in i18n - Widen type dropdown to w-16 for proper $ / % display --- backend/cmd/server/VERSION | 2 +- .../service/balance_notify_service.go | 15 +++-- .../components/account/QuotaNotifyToggle.vue | 58 ++++++------------- frontend/src/i18n/locales/en.ts | 2 +- frontend/src/i18n/locales/zh.ts | 2 +- 5 files changed, 30 insertions(+), 49 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index f0c65691..01eddd22 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.31 +0.1.110.34 diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index a392a13e..f5abbacc 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -115,13 +115,18 @@ type quotaDim struct { limit float64 } -// resolvedThreshold returns the effective threshold value. -// For percentage type, it computes threshold = limit * percentage / 100. +// resolvedThreshold converts the user-facing "remaining" threshold into a usage-based trigger point. +// The threshold represents how much quota REMAINS when the alert fires: +// - Fixed ($): threshold=400, limit=1000 → fires when usage reaches 600 (remaining drops to 400) +// - Percentage (%): threshold=30, limit=1000 → fires when usage reaches 700 (remaining drops to 30%) func (d quotaDim) resolvedThreshold() float64 { - if d.thresholdType == thresholdTypePercentage && d.limit > 0 { - return d.limit * d.threshold / 100 + if d.limit <= 0 { + return 0 } - return d.threshold + if d.thresholdType == thresholdTypePercentage { + return d.limit * (1 - d.threshold/100) + } + return d.limit - d.threshold } // buildQuotaDims returns the three quota dimensions for notification checking. diff --git a/frontend/src/components/account/QuotaNotifyToggle.vue b/frontend/src/components/account/QuotaNotifyToggle.vue index c7583a01..23979638 100644 --- a/frontend/src/components/account/QuotaNotifyToggle.vue +++ b/frontend/src/components/account/QuotaNotifyToggle.vue @@ -1,8 +1,4 @@ diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 5c593946..dcbcf03c 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -2264,7 +2264,7 @@ export default { quotaLimitAmount: 'Total Limit', quotaLimitAmountHint: 'Cumulative spending limit. Does not auto-reset.', quotaNotify: { - alert: 'Alert Threshold', + alert: 'Alert', enabled: 'Enable Alert', threshold: 'Alert Amount', thresholdPlaceholder: 'Enter percentage', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 141193c1..6dc9311c 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2262,7 +2262,7 @@ export default { quotaLimitAmount: '总限额', quotaLimitAmountHint: '累计消费上限,不会自动重置。', quotaNotify: { - alert: '告警阈值', + alert: '提醒阈值', enabled: '启用告警', threshold: '告警金额', thresholdPlaceholder: '输入百分比', From e27335acdd1667c6ae5d35d1bc362412adb98b3f Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 18:23:20 +0800 Subject: [PATCH 53/88] fix(ui): widen notify type dropdown to show % fully, align quota input widths --- backend/cmd/server/VERSION | 2 +- .../src/components/account/QuotaLimitCard.vue | 161 ++++++------------ .../components/account/QuotaNotifyToggle.vue | 2 +- 3 files changed, 54 insertions(+), 111 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 01eddd22..0a81f94d 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.34 +0.1.110.38 diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index ab109a9a..9051b9be 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -196,171 +196,114 @@ const onWeeklyModeChange = (e: Event) => { -
+
- -
-
- $ - + +
+ {{ t('admin.accounts.quotaDailyLimit') }} + {{ t('admin.accounts.quotaNotify.alert') }} +
+ +
+
+ $ +
-
+
-
-

- - +

+ +

- -
-
- $ - +
+ {{ t('admin.accounts.quotaWeeklyLimit') }} + {{ t('admin.accounts.quotaNotify.alert') }} +
+
+
+ $ +
-
+
-
-

- - +

+ +

- +
-
- -
-
- $ - +
+ {{ t('admin.accounts.quotaTotalLimit') }} + {{ t('admin.accounts.quotaNotify.alert') }} +
+
+
+ $ +
-

{{ t('admin.accounts.quotaTotalLimitHint') }}

+

{{ t('admin.accounts.quotaTotalLimitHint') }}

diff --git a/frontend/src/components/account/QuotaNotifyToggle.vue b/frontend/src/components/account/QuotaNotifyToggle.vue index 23979638..0548f661 100644 --- a/frontend/src/components/account/QuotaNotifyToggle.vue +++ b/frontend/src/components/account/QuotaNotifyToggle.vue @@ -42,7 +42,7 @@ const emit = defineEmits<{ +

{{ t('admin.settings.balanceNotify.rechargeUrlHint') }}

+
@@ -3027,6 +3032,7 @@ const form = reactive({ // Balance & quota notification balance_low_notify_enabled: false, balance_low_notify_threshold: 0, + balance_low_notify_recharge_url: '', account_quota_notify_enabled: false, account_quota_notify_emails: [] as NotifyEmailEntry[] }) @@ -3598,6 +3604,7 @@ async function saveSettings() { // Balance & quota notification balance_low_notify_enabled: form.balance_low_notify_enabled, balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0, + balance_low_notify_recharge_url: form.balance_low_notify_recharge_url || '', account_quota_notify_enabled: form.account_quota_notify_enabled, account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e) => e.email.trim() !== ''), } From 48b6c4811f2f8b58933e30c91687acb027aeedf4 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 18:44:36 +0800 Subject: [PATCH 55/88] fix(notify): auto-fill recharge URL with current origin when empty --- backend/cmd/server/VERSION | 2 +- frontend/src/views/admin/SettingsView.vue | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index ee13da53..9e3ea640 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.39 +0.1.110.40 diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index ef67ce22..57f6b35e 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -2707,7 +2707,7 @@
- +

{{ t('admin.settings.balanceNotify.rechargeUrlHint') }}

@@ -3262,6 +3262,8 @@ const addQuotaNotifyEmail = () => { form.account_quota_notify_emails.push({ email: '', disabled: false, verified: true }) } +const currentOrigin = typeof window !== 'undefined' ? window.location.origin : '' + // LinuxDo OAuth redirect URL suggestion const linuxdoRedirectUrlSuggestion = computed(() => { if (typeof window === 'undefined') return '' @@ -3604,7 +3606,7 @@ async function saveSettings() { // Balance & quota notification balance_low_notify_enabled: form.balance_low_notify_enabled, balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0, - balance_low_notify_recharge_url: form.balance_low_notify_recharge_url || '', + balance_low_notify_recharge_url: form.balance_low_notify_recharge_url || currentOrigin, account_quota_notify_enabled: form.account_quota_notify_enabled, account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e) => e.email.trim() !== ''), } From f571d8ffad4ac8f3c20d44a70221a0d1f7211d28 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 18:52:02 +0800 Subject: [PATCH 56/88] fix(notify): write back auto-filled recharge URL to form on save --- backend/cmd/server/VERSION | 2 +- frontend/src/views/admin/SettingsView.vue | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 9e3ea640..a69c1172 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.40 +0.1.110.41 diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 57f6b35e..3ef1c0ba 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -3606,7 +3606,7 @@ async function saveSettings() { // Balance & quota notification balance_low_notify_enabled: form.balance_low_notify_enabled, balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0, - balance_low_notify_recharge_url: form.balance_low_notify_recharge_url || currentOrigin, + balance_low_notify_recharge_url: (form.balance_low_notify_recharge_url = form.balance_low_notify_recharge_url || currentOrigin), account_quota_notify_enabled: form.account_quota_notify_enabled, account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e) => e.email.trim() !== ''), } From 6e9146e746a8743ad078230aedaabae85ef15e25 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 19:02:40 +0800 Subject: [PATCH 57/88] fix(notify): add recharge URL to admin settings GET response --- backend/cmd/server/VERSION | 2 +- backend/internal/handler/admin/setting_handler.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index a69c1172..ab6fbb6e 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.41 +0.1.110.42 diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index e31eb134..bc6d183c 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -178,6 +178,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, + BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled, AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails), PaymentEnabled: paymentCfg.Enabled, From a43da6225449c68f486b796139a4144c9fbe24fc Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 20:01:25 +0800 Subject: [PATCH 58/88] fix(accounts): unify modal width, add notify props to create, fix quota layout - EditAccountModal width changed from "normal" to "wide" (match CreateAccountModal) - CreateAccountModal now passes all quota notify props to QuotaLimitCard - QuotaLimitCard: when global notify disabled, hide title row, input takes full width - Quota alert email: show remaining quota + threshold (fixed/$, percentage/%) instead of usage trigger point --- backend/cmd/server/VERSION | 2 +- .../service/balance_notify_service.go | 35 +++++++---- .../components/account/CreateAccountModal.vue | 61 +++++++++++++++++++ .../components/account/EditAccountModal.vue | 2 +- .../src/components/account/QuotaLimitCard.vue | 23 ++++--- 5 files changed, 99 insertions(+), 24 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index ab6fbb6e..76129f5c 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.42 +0.1.110.44 diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 3951e88f..5191a26e 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -257,7 +257,7 @@ func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, account slog.Error("panic in quota notification", "recover", r) } }() - s.sendQuotaAlertEmails(adminEmails, accountID, accountName, platform, dim.name, newUsed, dim.limit, effectiveThreshold, siteName) + s.sendQuotaAlertEmails(adminEmails, accountID, accountName, platform, dim, newUsed, siteName) }() } @@ -384,15 +384,25 @@ func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userNam } // sendQuotaAlertEmails sends quota alert notification to admin emails. -func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accountID int64, accountName, platform, dimension string, used, limit, threshold float64, siteName string) { - dimLabel := quotaDimLabels[dimension] +func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accountID int64, accountName, platform string, dim quotaDim, used float64, siteName string) { + dimLabel := quotaDimLabels[dim.name] if dimLabel == "" { - dimLabel = dimension + dimLabel = dim.name + } + + // Format the remaining-based threshold for display + thresholdDisplay := fmt.Sprintf("$%.2f", dim.threshold) + if dim.thresholdType == thresholdTypePercentage { + thresholdDisplay = fmt.Sprintf("%.0f%%", dim.threshold) + } + remaining := dim.limit - used + if remaining < 0 { + remaining = 0 } subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName)) - body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, limit, threshold, html.EscapeString(siteName)) - s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension) + body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, dim.limit, remaining, thresholdDisplay, html.EscapeString(siteName)) + s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dim.name) } // sanitizeEmailHeader removes CR/LF characters to prevent SMTP header injection. @@ -440,7 +450,7 @@ const balanceLowEmailTemplate = ` ` // quotaAlertEmailTemplate is the HTML template for account quota alert notifications. -// Format args: siteName, accountID, accountName, platform, dimLabel, used, limitStr, threshold. +// Format args: siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay. const quotaAlertEmailTemplate = ` @@ -469,10 +479,11 @@ const quotaAlertEmailTemplate = `
维度 / Dimension%s
已使用 / Used$%.2f
限额 / Limit%s
-
告警阈值 / Threshold$%.2f
+
剩余额度 / Remaining$%.2f
+
提醒阈值 / Alert Threshold%s
-

账号配额用量已达到告警阈值,请及时关注。

-

Account quota usage has reached the alert threshold.

+

账号剩余额度已低于提醒阈值,请及时关注。

+

Account remaining quota has fallen below the alert threshold.

@@ -490,11 +501,11 @@ func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance } // buildQuotaAlertEmailBody builds HTML email for account quota alert. -func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountID int64, accountName, platform, dimLabel string, used, limit, threshold float64, siteName string) string { +func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountID int64, accountName, platform, dimLabel string, used, limit, remaining float64, thresholdDisplay, siteName string) string { limitStr := fmt.Sprintf("$%.2f", limit) if limit <= 0 { limitStr = "无限制 / Unlimited" } - return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, threshold) + return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay) } diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index a1496fa8..ba7bad51 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1493,6 +1493,15 @@ :dailyLimit="editQuotaDailyLimit" :weeklyLimit="editQuotaWeeklyLimit" :quotaNotifyGlobalEnabled="quotaNotifyGlobalEnabled" + :quotaNotifyDailyEnabled="quotaNotifyDailyEnabled" + :quotaNotifyDailyThreshold="quotaNotifyDailyThreshold" + :quotaNotifyDailyThresholdType="quotaNotifyDailyThresholdType" + :quotaNotifyWeeklyEnabled="quotaNotifyWeeklyEnabled" + :quotaNotifyWeeklyThreshold="quotaNotifyWeeklyThreshold" + :quotaNotifyWeeklyThresholdType="quotaNotifyWeeklyThresholdType" + :quotaNotifyTotalEnabled="quotaNotifyTotalEnabled" + :quotaNotifyTotalThreshold="quotaNotifyTotalThreshold" + :quotaNotifyTotalThresholdType="quotaNotifyTotalThresholdType" :dailyResetMode="editDailyResetMode" :dailyResetHour="editDailyResetHour" :weeklyResetMode="editWeeklyResetMode" @@ -1502,6 +1511,15 @@ @update:totalLimit="editQuotaLimit = $event" @update:dailyLimit="editQuotaDailyLimit = $event" @update:weeklyLimit="editQuotaWeeklyLimit = $event" + @update:quotaNotifyDailyEnabled="quotaNotifyDailyEnabled = $event" + @update:quotaNotifyDailyThreshold="quotaNotifyDailyThreshold = $event" + @update:quotaNotifyDailyThresholdType="quotaNotifyDailyThresholdType = $event" + @update:quotaNotifyWeeklyEnabled="quotaNotifyWeeklyEnabled = $event" + @update:quotaNotifyWeeklyThreshold="quotaNotifyWeeklyThreshold = $event" + @update:quotaNotifyWeeklyThresholdType="quotaNotifyWeeklyThresholdType = $event" + @update:quotaNotifyTotalEnabled="quotaNotifyTotalEnabled = $event" + @update:quotaNotifyTotalThreshold="quotaNotifyTotalThreshold = $event" + @update:quotaNotifyTotalThresholdType="quotaNotifyTotalThresholdType = $event" @update:dailyResetMode="editDailyResetMode = $event" @update:dailyResetHour="editDailyResetHour = $event" @update:weeklyResetMode="editWeeklyResetMode = $event" @@ -1527,6 +1545,15 @@ :dailyLimit="editQuotaDailyLimit" :weeklyLimit="editQuotaWeeklyLimit" :quotaNotifyGlobalEnabled="quotaNotifyGlobalEnabled" + :quotaNotifyDailyEnabled="quotaNotifyDailyEnabled" + :quotaNotifyDailyThreshold="quotaNotifyDailyThreshold" + :quotaNotifyDailyThresholdType="quotaNotifyDailyThresholdType" + :quotaNotifyWeeklyEnabled="quotaNotifyWeeklyEnabled" + :quotaNotifyWeeklyThreshold="quotaNotifyWeeklyThreshold" + :quotaNotifyWeeklyThresholdType="quotaNotifyWeeklyThresholdType" + :quotaNotifyTotalEnabled="quotaNotifyTotalEnabled" + :quotaNotifyTotalThreshold="quotaNotifyTotalThreshold" + :quotaNotifyTotalThresholdType="quotaNotifyTotalThresholdType" :dailyResetMode="editDailyResetMode" :dailyResetHour="editDailyResetHour" :weeklyResetMode="editWeeklyResetMode" @@ -1536,6 +1563,15 @@ @update:totalLimit="editQuotaLimit = $event" @update:dailyLimit="editQuotaDailyLimit = $event" @update:weeklyLimit="editQuotaWeeklyLimit = $event" + @update:quotaNotifyDailyEnabled="quotaNotifyDailyEnabled = $event" + @update:quotaNotifyDailyThreshold="quotaNotifyDailyThreshold = $event" + @update:quotaNotifyDailyThresholdType="quotaNotifyDailyThresholdType = $event" + @update:quotaNotifyWeeklyEnabled="quotaNotifyWeeklyEnabled = $event" + @update:quotaNotifyWeeklyThreshold="quotaNotifyWeeklyThreshold = $event" + @update:quotaNotifyWeeklyThresholdType="quotaNotifyWeeklyThresholdType = $event" + @update:quotaNotifyTotalEnabled="quotaNotifyTotalEnabled = $event" + @update:quotaNotifyTotalThreshold="quotaNotifyTotalThreshold = $event" + @update:quotaNotifyTotalThresholdType="quotaNotifyTotalThresholdType = $event" @update:dailyResetMode="editDailyResetMode = $event" @update:dailyResetHour="editDailyResetHour = $event" @update:weeklyResetMode="editWeeklyResetMode = $event" @@ -3041,6 +3077,15 @@ const anthropicPassthroughEnabled = ref(false) const webSearchEmulationMode = ref('default') const webSearchGlobalEnabled = ref(false) const quotaNotifyGlobalEnabled = ref(false) +const quotaNotifyDailyEnabled = ref(null) +const quotaNotifyDailyThreshold = ref(null) +const quotaNotifyDailyThresholdType = ref(null) +const quotaNotifyWeeklyEnabled = ref(null) +const quotaNotifyWeeklyThreshold = ref(null) +const quotaNotifyWeeklyThresholdType = ref(null) +const quotaNotifyTotalEnabled = ref(null) +const quotaNotifyTotalThreshold = ref(null) +const quotaNotifyTotalThresholdType = ref(null) // Load global feature states once adminAPI.settings.getWebSearchEmulationConfig().then(cfg => { @@ -4153,6 +4198,22 @@ const createAccountAndFinish = async ( if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') { quotaExtra.quota_reset_timezone = editResetTimezone.value || 'UTC' } + // Quota notify config + if (quotaNotifyDailyEnabled.value) { + quotaExtra.quota_notify_daily_enabled = true + if (quotaNotifyDailyThreshold.value != null) quotaExtra.quota_notify_daily_threshold = quotaNotifyDailyThreshold.value + quotaExtra.quota_notify_daily_threshold_type = quotaNotifyDailyThresholdType.value || 'fixed' + } + if (quotaNotifyWeeklyEnabled.value) { + quotaExtra.quota_notify_weekly_enabled = true + if (quotaNotifyWeeklyThreshold.value != null) quotaExtra.quota_notify_weekly_threshold = quotaNotifyWeeklyThreshold.value + quotaExtra.quota_notify_weekly_threshold_type = quotaNotifyWeeklyThresholdType.value || 'fixed' + } + if (quotaNotifyTotalEnabled.value) { + quotaExtra.quota_notify_total_enabled = true + if (quotaNotifyTotalThreshold.value != null) quotaExtra.quota_notify_total_threshold = quotaNotifyTotalThreshold.value + quotaExtra.quota_notify_total_threshold_type = quotaNotifyTotalThresholdType.value || 'fixed' + } if (Object.keys(quotaExtra).length > 0) { finalExtra = quotaExtra } diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 613738d2..92761b35 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -2,7 +2,7 @@
{
- -
+ +
{{ t('admin.accounts.quotaDailyLimit') }} - {{ t('admin.accounts.quotaNotify.alert') }} + {{ t('admin.accounts.quotaNotify.alert') }}
+
-
+
$
@@ -238,12 +239,13 @@ const onWeeklyModeChange = (e: Event) => {
-
+
{{ t('admin.accounts.quotaWeeklyLimit') }} - {{ t('admin.accounts.quotaNotify.alert') }} + {{ t('admin.accounts.quotaNotify.alert') }}
+
-
+
$
@@ -287,12 +289,13 @@ const onWeeklyModeChange = (e: Event) => {
-
+
{{ t('admin.accounts.quotaTotalLimit') }} - {{ t('admin.accounts.quotaNotify.alert') }} + {{ t('admin.accounts.quotaNotify.alert') }}
+
-
+
$
From ca673f98995b286c35b19b659088c5970b4eea54 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 20:35:38 +0800 Subject: [PATCH 59/88] test: add 66 unit tests for balance/quota notify + plan validation balance_notify_service_test.go (27 tests): - resolveBalanceThreshold: fixed/percentage/zero recharged/empty type - quotaDim.resolvedThreshold: fixed normal/exceed/equal limit, percentage 0/30/100/>100, zero/negative limit - sanitizeEmailHeader: CRLF/CR/LF/clean/empty/multiple newlines - buildQuotaDims / buildQuotaDimsFromState: all dimensions, empty extra, state-vs-account precedence - collectBalanceNotifyRecipients: empty, filter disabled/unverified, case-insensitive dedup, skip empty, trim balance_notify_check_test.go (16 tests): - CheckBalanceAfterDeduction guard clauses: nil user/disabled/global-off/threshold=0/user-override/no-crossing - CheckAccountQuotaAfterIncrement guards: nil account/zero cost/negative cost/global-disabled - getBalanceNotifyConfig: all fields, disabled, invalid threshold - isAccountQuotaNotifyEnabled: missing/false/true - getSiteName: default fallback + configured balance_notify_email_body_test.go (10 tests): - Guards against fmt.Sprintf arg-count mismatches in email templates - Verifies HTML escaping of recharge URL - Verifies CSS %% escape produces literal % in output - Verifies unlimited/percentage/over-quota display branches payment_config_plans_validation_test.go (13 tests): - validatePlanRequired: all 5 validation branches + whitespace handling --- .../service/balance_notify_check_test.go | 180 +++++++++++ .../service/balance_notify_email_body_test.go | 147 +++++++++ .../service/balance_notify_service_test.go | 280 ++++++++++++++++++ .../payment_config_plans_validation_test.go | 89 ++++++ 4 files changed, 696 insertions(+) create mode 100644 backend/internal/service/balance_notify_check_test.go create mode 100644 backend/internal/service/balance_notify_email_body_test.go create mode 100644 backend/internal/service/balance_notify_service_test.go create mode 100644 backend/internal/service/payment_config_plans_validation_test.go diff --git a/backend/internal/service/balance_notify_check_test.go b/backend/internal/service/balance_notify_check_test.go new file mode 100644 index 00000000..955f3129 --- /dev/null +++ b/backend/internal/service/balance_notify_check_test.go @@ -0,0 +1,180 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// newBalanceNotifyServiceForTest constructs a BalanceNotifyService with an +// in-memory settings repo and a non-nil emailService so that the guard-clause +// nil-checks pass. The emailService is intentionally minimal — tests must +// avoid crossing scenarios that would actually dispatch emails. +func newBalanceNotifyServiceForTest() (*BalanceNotifyService, *mockSettingRepo) { + repo := newMockSettingRepo() + // EmailService is a concrete type; construct with the same repo so that + // any accidental fallback reads still succeed. Tests should not trigger a + // crossing that reaches SendEmail. + email := NewEmailService(repo, nil) + return NewBalanceNotifyService(email, repo, nil), repo +} + +// ---------- guard clauses ---------- + +func TestCheckBalanceAfterDeduction_NilUser(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + // Should not panic. + s.CheckBalanceAfterDeduction(context.Background(), nil, 100, 50) +} + +func TestCheckBalanceAfterDeduction_UserNotifyDisabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "10" + u := &User{ID: 1, BalanceNotifyEnabled: false} + // Even with a crossing, disabled flag short-circuits. + s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15) +} + +func TestCheckBalanceAfterDeduction_GlobalDisabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "false" + u := &User{ID: 1, BalanceNotifyEnabled: true} + s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15) +} + +func TestCheckBalanceAfterDeduction_ThresholdZero(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "0" + u := &User{ID: 1, BalanceNotifyEnabled: true} + s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15) +} + +func TestCheckBalanceAfterDeduction_UserThresholdOverride(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "100" // global default + customThreshold := 5.0 + u := &User{ + ID: 1, + BalanceNotifyEnabled: true, + BalanceNotifyThreshold: &customThreshold, + } + // User's 5.0 threshold takes precedence over global 100. 20 -> 15 does not + // cross 5, so nothing fires (verified by absence of panic). + s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15) +} + +func TestCheckBalanceAfterDeduction_NoCrossingNotFired(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "10" + u := &User{ID: 1, BalanceNotifyEnabled: true} + + // 100 -> 95, both remain above threshold=10, no crossing. + s.CheckBalanceAfterDeduction(context.Background(), u, 100, 5) + // 5 -> 3, both already below threshold, no crossing (only fires on first + // cross from above-to-below). + s.CheckBalanceAfterDeduction(context.Background(), u, 5, 2) +} + +// ---------- nil-service guards on CheckAccountQuotaAfterIncrement ---------- + +func TestCheckAccountQuotaAfterIncrement_NilAccount(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + // Should not panic. + s.CheckAccountQuotaAfterIncrement(context.Background(), nil, 10, nil) +} + +func TestCheckAccountQuotaAfterIncrement_ZeroCost(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + s.CheckAccountQuotaAfterIncrement(context.Background(), a, 0, nil) +} + +func TestCheckAccountQuotaAfterIncrement_NegativeCost(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + s.CheckAccountQuotaAfterIncrement(context.Background(), a, -5, nil) +} + +func TestCheckAccountQuotaAfterIncrement_GlobalDisabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false" + a := &Account{ + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "quota_notify_daily_enabled": true, + "quota_notify_daily_threshold": 100.0, + "quota_daily_limit": 1000.0, + "quota_daily_used": 950.0, + }, + } + // Global disabled → no processing even if a dim would cross. + s.CheckAccountQuotaAfterIncrement(context.Background(), a, 100, nil) +} + +// ---------- sanity: internal helpers still work ---------- + +func TestGetBalanceNotifyConfig_AllFields(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "12.5" + repo.data[SettingKeyBalanceLowNotifyRechargeURL] = "https://example.com/pay" + + enabled, threshold, url := s.getBalanceNotifyConfig(context.Background()) + require.True(t, enabled) + require.Equal(t, 12.5, threshold) + require.Equal(t, "https://example.com/pay", url) +} + +func TestGetBalanceNotifyConfig_Disabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "false" + + enabled, _, _ := s.getBalanceNotifyConfig(context.Background()) + require.False(t, enabled) +} + +func TestGetBalanceNotifyConfig_InvalidThreshold(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "not-a-number" + + enabled, threshold, _ := s.getBalanceNotifyConfig(context.Background()) + require.True(t, enabled) + require.Equal(t, 0.0, threshold) +} + +func TestIsAccountQuotaNotifyEnabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + + // Missing key → false + require.False(t, s.isAccountQuotaNotifyEnabled(context.Background())) + + // Explicit "false" + repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false" + require.False(t, s.isAccountQuotaNotifyEnabled(context.Background())) + + // Explicit "true" + repo.data[SettingKeyAccountQuotaNotifyEnabled] = "true" + require.True(t, s.isAccountQuotaNotifyEnabled(context.Background())) +} + +func TestGetSiteName_FallsBackToDefault(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + name := s.getSiteName(context.Background()) + require.Equal(t, defaultSiteName, name) +} + +func TestGetSiteName_Configured(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeySiteName] = "My Site" + require.Equal(t, "My Site", s.getSiteName(context.Background())) +} diff --git a/backend/internal/service/balance_notify_email_body_test.go b/backend/internal/service/balance_notify_email_body_test.go new file mode 100644 index 00000000..9baf164e --- /dev/null +++ b/backend/internal/service/balance_notify_email_body_test.go @@ -0,0 +1,147 @@ +//go:build unit + +package service + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// These tests guard against fmt.Sprintf arg-count mismatches in the email +// templates. A mismatch would produce "%!(EXTRA ...)" or "%!v(MISSING)" in +// the output, which these assertions will catch. + +// ---------- buildBalanceLowEmailBody ---------- + +func TestBuildBalanceLowEmailBody_ContainsRequiredFields(t *testing.T) { + s := &BalanceNotifyService{} + body := s.buildBalanceLowEmailBody("Alice", 3.14, 10.0, "MySite", "") + + // All substituted values should appear in the output. + require.Contains(t, body, "MySite") + require.Contains(t, body, "Alice") + require.Contains(t, body, "$3.14") + require.Contains(t, body, "$10.00") + + // No fmt.Sprintf format error markers. + require.NotContains(t, body, "%!") + require.NotContains(t, body, "MISSING") + require.NotContains(t, body, "EXTRA") +} + +func TestBuildBalanceLowEmailBody_WithRechargeURL(t *testing.T) { + s := &BalanceNotifyService{} + body := s.buildBalanceLowEmailBody("Bob", 5.0, 20.0, "Site", "https://example.com/pay") + + // The recharge anchor element should appear with the URL. + require.Contains(t, body, `href="https://example.com/pay"`) + require.Contains(t, body, "立即充值") + require.NotContains(t, body, "%!") +} + +func TestBuildBalanceLowEmailBody_RechargeURLEscaped(t *testing.T) { + s := &BalanceNotifyService{} + // Try a URL with characters that need HTML escaping. + body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", `https://example.com/?a=1&b= + + diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index 5f0c7c2c..77e437a8 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -1,7 +1,7 @@ diff --git a/frontend/src/composables/useQuotaNotifyState.ts b/frontend/src/composables/useQuotaNotifyState.ts new file mode 100644 index 00000000..1c6705d3 --- /dev/null +++ b/frontend/src/composables/useQuotaNotifyState.ts @@ -0,0 +1,69 @@ +import { reactive, ref } from 'vue' +import { adminAPI } from '@/api/admin' +import { QUOTA_THRESHOLD_TYPE_FIXED } from '@/constants/account' + +export const QUOTA_NOTIFY_DIMS = ['daily', 'weekly', 'total'] as const +export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number] + +interface DimState { + enabled: boolean | null + threshold: number | null + thresholdType: string | null +} + +export function useQuotaNotifyState() { + const globalEnabled = ref(false) + const state = reactive>({ + daily: { enabled: null, threshold: null, thresholdType: null }, + weekly: { enabled: null, threshold: null, thresholdType: null }, + total: { enabled: null, threshold: null, thresholdType: null }, + }) + + function loadGlobalState() { + adminAPI.settings + .getSettings() + .then((settings) => { + globalEnabled.value = settings.account_quota_notify_enabled === true + }) + .catch(() => { + globalEnabled.value = false + }) + } + + function loadFromExtra(extra: Record | null | undefined) { + for (const d of QUOTA_NOTIFY_DIMS) { + state[d].enabled = (extra?.[`quota_notify_${d}_enabled`] as boolean) ?? null + state[d].threshold = (extra?.[`quota_notify_${d}_threshold`] as number) ?? null + state[d].thresholdType = (extra?.[`quota_notify_${d}_threshold_type`] as string) ?? null + } + } + + function writeToExtra(extra: Record, mode: 'create' | 'update') { + for (const d of QUOTA_NOTIFY_DIMS) { + const s = state[d] + if (s.enabled) { + extra[`quota_notify_${d}_enabled`] = true + if (s.threshold != null) { + extra[`quota_notify_${d}_threshold`] = s.threshold + } else if (mode === 'update') { + delete extra[`quota_notify_${d}_threshold`] + } + extra[`quota_notify_${d}_threshold_type`] = s.thresholdType || QUOTA_THRESHOLD_TYPE_FIXED + } else if (mode === 'update') { + delete extra[`quota_notify_${d}_enabled`] + delete extra[`quota_notify_${d}_threshold`] + delete extra[`quota_notify_${d}_threshold_type`] + } + } + } + + function reset() { + for (const d of QUOTA_NOTIFY_DIMS) { + state[d].enabled = null + state[d].threshold = null + state[d].thresholdType = null + } + } + + return { globalEnabled, state, loadGlobalState, loadFromExtra, writeToExtra, reset } +} diff --git a/frontend/src/views/admin/orders/AdminPaymentPlansView.vue b/frontend/src/views/admin/orders/AdminPaymentPlansView.vue index 639b4f66..28b82da5 100644 --- a/frontend/src/views/admin/orders/AdminPaymentPlansView.vue +++ b/frontend/src/views/admin/orders/AdminPaymentPlansView.vue @@ -29,7 +29,7 @@ @@ -67,86 +67,14 @@
- - -
-
- - -
-
- - -
-
- - -
-
- -
-
-
{{ t('payment.admin.dailyLimit') }}: {{ selectedGroupInfo.daily_limit_usd != null ? '$' + selectedGroupInfo.daily_limit_usd : t('payment.admin.unlimited') }}
-
{{ t('payment.admin.weeklyLimit') }}: {{ selectedGroupInfo.weekly_limit_usd != null ? '$' + selectedGroupInfo.weekly_limit_usd : t('payment.admin.unlimited') }}
-
{{ t('payment.admin.monthlyLimit') }}: {{ selectedGroupInfo.monthly_limit_usd != null ? '$' + selectedGroupInfo.monthly_limit_usd : t('payment.admin.unlimited') }}
-
-
- -
-
-
-
-
-
-
-
-
-

{{ t('payment.admin.featuresHint') }}

-
-
- - -
- - - + From 74f8a30f861f2b5072f5916265abbf61d3448b12 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 13 Apr 2026 23:35:59 +0800 Subject: [PATCH 64/88] fix: address audit findings for websearch, email verification, and pricing - Fix websearch provider failover: proxy error from provider-specific proxy now continues to next provider instead of aborting the entire loop - Fix SMTP failure locking users out: send email first, then write cache and increment rate counter - Fix notify email cache key case sensitivity: normalize to lowercase - Add OriginalPrice validation to validatePlanPatch and validatePlanRequired - Add empty scope validation for channel pricing rules (group_ids/account_ids) - Add platform color to account search dropdown in channel pricing rules --- .../internal/handler/admin/channel_handler.go | 11 +++ backend/internal/pkg/websearch/manager.go | 13 +++- backend/internal/repository/email_cache.go | 5 +- .../internal/service/payment_config_plans.go | 10 ++- .../payment_config_plans_validation_test.go | 75 ++++++++++++++----- backend/internal/service/user_service.go | 8 +- frontend/src/views/admin/ChannelsView.vue | 7 +- 7 files changed, 103 insertions(+), 26 deletions(-) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 2d4cd56a..ee76a750 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -1,6 +1,7 @@ package admin import ( + "fmt" "strconv" "strings" @@ -351,6 +352,11 @@ func (h *ChannelHandler) Create(c *gin.Context) { var statsRules []service.AccountStatsPricingRule for i, r := range req.AccountStatsPricingRules { + if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE", + fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) + return + } rule := accountStatsPricingRuleRequestToService(r) rule.SortOrder = i statsRules = append(statsRules, rule) @@ -409,6 +415,11 @@ func (h *ChannelHandler) Update(c *gin.Context) { if req.AccountStatsPricingRules != nil { statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules)) for i, r := range *req.AccountStatsPricingRules { + if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE", + fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) + return + } rule := accountStatsPricingRuleRequestToService(r) rule.SortOrder = i statsRules = append(statsRules, rule) diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go index ae0683ad..27592459 100644 --- a/backend/internal/pkg/websearch/manager.go +++ b/backend/internal/pkg/websearch/manager.go @@ -111,9 +111,18 @@ func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) } if isProxyError(err) { m.markProxyUnavailable(ctx, cfg, req.ProxyURL) - slog.Warn("websearch: proxy error, marking unavailable", + if req.ProxyURL != "" { + // Account-level proxy is shared by all providers — no point + // trying others with the same broken proxy; signal account switch. + slog.Warn("websearch: account proxy error, aborting failover", + "provider", cfg.Type, "error", err) + return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error()) + } + // Provider-specific proxy failed — try the next provider which + // may use a different (or no) proxy. + slog.Warn("websearch: provider proxy error, trying next provider", "provider", cfg.Type, "error", err) - return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error()) + continue } slog.Warn("websearch: provider search failed", "provider", cfg.Type, "error", err) diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go index ed903e0d..1356163d 100644 --- a/backend/internal/repository/email_cache.go +++ b/backend/internal/repository/email_cache.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -24,8 +25,10 @@ func verifyCodeKey(email string) string { } // notifyVerifyKey generates the Redis key for notify email verification code. +// Email is lowercased to prevent case-sensitive key mismatch (the business layer +// uses strings.EqualFold for comparison). func notifyVerifyKey(email string) string { - return notifyVerifyKeyPrefix + email + return notifyVerifyKeyPrefix + strings.ToLower(email) } // passwordResetKey generates the Redis key for password reset token. diff --git a/backend/internal/service/payment_config_plans.go b/backend/internal/service/payment_config_plans.go index 8a5e1924..6753071d 100644 --- a/backend/internal/service/payment_config_plans.go +++ b/backend/internal/service/payment_config_plans.go @@ -12,7 +12,7 @@ import ( ) // validatePlanRequired checks that all required fields for a plan are provided. -func validatePlanRequired(name string, groupID int64, price float64, validityDays int, validityUnit string) error { +func validatePlanRequired(name string, groupID int64, price float64, validityDays int, validityUnit string, originalPrice *float64) error { if strings.TrimSpace(name) == "" { return infraerrors.BadRequest("PLAN_NAME_REQUIRED", "plan name is required") } @@ -28,6 +28,9 @@ func validatePlanRequired(name string, groupID int64, price float64, validityDay if strings.TrimSpace(validityUnit) == "" { return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required") } + if originalPrice != nil && *originalPrice < 0 { + return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0") + } return nil } @@ -48,6 +51,9 @@ func validatePlanPatch(req UpdatePlanRequest) error { if req.ValidityUnit != nil && strings.TrimSpace(*req.ValidityUnit) == "" { return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required") } + if req.OriginalPrice != nil && *req.OriginalPrice < 0 { + return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0") + } return nil } @@ -115,7 +121,7 @@ func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.S } func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) { - if err := validatePlanRequired(req.Name, req.GroupID, req.Price, req.ValidityDays, req.ValidityUnit); err != nil { + if err := validatePlanRequired(req.Name, req.GroupID, req.Price, req.ValidityDays, req.ValidityUnit, req.OriginalPrice); err != nil { return nil, err } b := s.entClient.SubscriptionPlan.Create(). diff --git a/backend/internal/service/payment_config_plans_validation_test.go b/backend/internal/service/payment_config_plans_validation_test.go index bc9c0048..9a2d8716 100644 --- a/backend/internal/service/payment_config_plans_validation_test.go +++ b/backend/internal/service/payment_config_plans_validation_test.go @@ -9,81 +9,122 @@ import ( ) func TestValidatePlanRequired_AllValid(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, 30, "days") + err := validatePlanRequired("Pro", 1, 9.99, 30, "days", nil) require.NoError(t, err) } func TestValidatePlanRequired_EmptyName(t *testing.T) { - err := validatePlanRequired("", 1, 9.99, 30, "days") + err := validatePlanRequired("", 1, 9.99, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "plan name") } func TestValidatePlanRequired_WhitespaceName(t *testing.T) { - err := validatePlanRequired(" ", 1, 9.99, 30, "days") + err := validatePlanRequired(" ", 1, 9.99, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "plan name") } func TestValidatePlanRequired_ZeroGroupID(t *testing.T) { - err := validatePlanRequired("Pro", 0, 9.99, 30, "days") + err := validatePlanRequired("Pro", 0, 9.99, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "group") } func TestValidatePlanRequired_NegativeGroupID(t *testing.T) { - err := validatePlanRequired("Pro", -1, 9.99, 30, "days") + err := validatePlanRequired("Pro", -1, 9.99, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "group") } func TestValidatePlanRequired_ZeroPrice(t *testing.T) { - err := validatePlanRequired("Pro", 1, 0, 30, "days") + err := validatePlanRequired("Pro", 1, 0, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "price") } func TestValidatePlanRequired_NegativePrice(t *testing.T) { - err := validatePlanRequired("Pro", 1, -5, 30, "days") + err := validatePlanRequired("Pro", 1, -5, 30, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "price") } func TestValidatePlanRequired_ZeroValidityDays(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, 0, "days") + err := validatePlanRequired("Pro", 1, 9.99, 0, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "validity days") } func TestValidatePlanRequired_NegativeValidityDays(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, -7, "days") + err := validatePlanRequired("Pro", 1, 9.99, -7, "days", nil) require.Error(t, err) require.Contains(t, err.Error(), "validity days") } func TestValidatePlanRequired_EmptyValidityUnit(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, 30, "") + err := validatePlanRequired("Pro", 1, 9.99, 30, "", nil) require.Error(t, err) require.Contains(t, err.Error(), "validity unit") } func TestValidatePlanRequired_WhitespaceValidityUnit(t *testing.T) { - err := validatePlanRequired("Pro", 1, 9.99, 30, " ") + err := validatePlanRequired("Pro", 1, 9.99, 30, " ", nil) require.Error(t, err) require.Contains(t, err.Error(), "validity unit") } func TestValidatePlanRequired_NameValidatedFirst(t *testing.T) { - // When multiple fields are invalid, name should be reported first - // (follows the order of checks in the function). - err := validatePlanRequired("", 0, 0, 0, "") + err := validatePlanRequired("", 0, 0, 0, "", nil) require.Error(t, err) require.Contains(t, err.Error(), "plan name") } func TestValidatePlanRequired_TrimmedValidName(t *testing.T) { - // Whitespace-surrounded but non-empty name is accepted (trimmed check only - // rejects pure whitespace). - err := validatePlanRequired(" Pro ", 1, 9.99, 30, "days") + err := validatePlanRequired(" Pro ", 1, 9.99, 30, "days", nil) + require.NoError(t, err) +} + +func TestValidatePlanRequired_NegativeOriginalPrice(t *testing.T) { + neg := -10.0 + err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &neg) + require.Error(t, err) + require.Contains(t, err.Error(), "original price") +} + +func TestValidatePlanRequired_ZeroOriginalPrice(t *testing.T) { + zero := 0.0 + err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &zero) + require.NoError(t, err) +} + +func TestValidatePlanRequired_ValidOriginalPrice(t *testing.T) { + op := 19.99 + err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &op) + require.NoError(t, err) +} + +// --- validatePlanPatch tests --- + +func TestValidatePlanPatch_NegativeOriginalPrice(t *testing.T) { + neg := -5.0 + err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &neg}) + require.Error(t, err) + require.Contains(t, err.Error(), "original price") +} + +func TestValidatePlanPatch_ZeroOriginalPrice(t *testing.T) { + zero := 0.0 + err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &zero}) + require.NoError(t, err) +} + +func TestValidatePlanPatch_ValidOriginalPrice(t *testing.T) { + op := 29.99 + err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &op}) + require.NoError(t, err) +} + +func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil}) require.NoError(t, err) } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 0da73762..7602d162 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -291,6 +291,12 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema return fmt.Errorf("generate code: %w", err) } + // Send email first — if SMTP fails, don't write cache or increment counters, + // so the user is not locked out by cooldown/rate-limit for a code they never received. + if err := s.sendNotifyVerifyEmail(ctx, emailService, email, code); err != nil { + return err + } + if err := saveNotifyVerifyCode(ctx, cache, email, code); err != nil { return err } @@ -300,7 +306,7 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema slog.Error("failed to increment notify code user rate", "user_id", userID, "error", err) } - return s.sendNotifyVerifyEmail(ctx, emailService, email, code) + return nil } // checkNotifyCodeRateLimit checks both email cooldown and user-level rate limit. diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 2ca1141d..60704b65 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -511,7 +511,7 @@ :class="{ 'opacity-50': rule.account_ids.includes(account.id) }" :disabled="rule.account_ids.includes(account.id)" > - {{ account.name }} + {{ account.name }} #{{ account.id }}
@@ -595,6 +595,7 @@ import type { PricingFormEntry } from '@/components/admin/channel/types' import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types' import type { AdminGroup, GroupPlatform } from '@/types' import type { Column } from '@/components/common/types' +import { platformTextClass } from '@/utils/platformColors' import AppLayout from '@/components/layout/AppLayout.vue' import TablePageLayout from '@/components/layout/TablePageLayout.vue' import DataTable from '@/components/common/DataTable.vue' @@ -911,7 +912,7 @@ function getGroupNameById(groupId: number): string { } // ── Account search for pricing rules ── -interface SimpleAccount { id: number; name: string } +interface SimpleAccount { id: number; name: string; platform: string } const ruleAccountSearchKeyword = ref>({}) const ruleAccountSearchResults = ref>({}) @@ -924,7 +925,7 @@ const ruleAccountSearchRunner = useKeyedDebouncedSearch({ search: async (keyword, { key, signal }) => { const platform = key.split('-')[0] const res = await adminAPI.accounts.list(1, 20, { platform, search: keyword }, { signal }) - return res.items.map(a => ({ id: a.id, name: a.name })) + return res.items.map(a => ({ id: a.id, name: a.name, platform: a.platform })) }, onSuccess: (key, result) => { ruleAccountSearchResults.value[key] = result }, onError: (key) => { ruleAccountSearchResults.value[key] = [] }, From a9880ee7b92365482154f986ea4b4d5256ffd7ec Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 00:26:20 +0800 Subject: [PATCH 65/88] =?UTF-8?q?fix:=20round-2=20audit=20fixes=20?= =?UTF-8?q?=E2=80=94=20security,=20code=20quality,=20and=20UI=20improvemen?= =?UTF-8?q?ts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Security (HIGH): - Normalize all Redis cache keys to lowercase (verifyCode, passwordReset) - Fix verify code TTL renewal on failed attempts: use remaining TTL via ExpiresAt field instead of resetting to full 15-minute window - Add 3 missing fields to diffSettings audit log (promo_code, invitation_code, custom_endpoints) Code quality (MEDIUM): - Extract filterVerifiedEmails shared helper (balance_notify_service.go) - Add Pricing array non-empty validation for channel pricing rules - Add platform token semantics comment in gateway_service.go - Complete validatePlanPatch test coverage (+10 test cases) - Replace string types with QuotaThresholdType/QuotaResetMode across frontend - Remove duplicate getPlatformTextColor/getRateBadgeClass in ChannelsView - Return EMAIL_NOT_FOUND error on RemoveNotifyEmail miss UI improvements: - Reorder cost tooltip: user billing above separator, account billing below - Add NaN guard to accountBilled function - Move timezone selector inline into reset-mode row (no longer standalone) --- .../internal/handler/admin/channel_handler.go | 10 + .../internal/handler/admin/setting_handler.go | 9 + backend/internal/repository/email_cache.go | 7 +- .../service/balance_notify_service.go | 44 +- backend/internal/service/email_service.go | 8 +- backend/internal/service/gateway_service.go | 591 +++++++++++++----- .../payment_config_plans_validation_test.go | 63 ++ backend/internal/service/user_service.go | 15 +- .../components/account/QuotaDimensionRow.vue | 28 +- .../src/components/account/QuotaLimitCard.vue | 48 +- .../components/account/QuotaNotifyToggle.vue | 8 +- .../src/components/admin/usage/UsageTable.vue | 12 +- .../src/composables/useQuotaNotifyState.ts | 6 +- frontend/src/constants/account.ts | 5 + frontend/src/views/admin/ChannelsView.vue | 42 +- 15 files changed, 605 insertions(+), 291 deletions(-) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index ee76a750..1a328551 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -357,6 +357,11 @@ func (h *ChannelHandler) Create(c *gin.Context) { fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) return } + if len(r.Pricing) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING", + fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1))) + return + } rule := accountStatsPricingRuleRequestToService(r) rule.SortOrder = i statsRules = append(statsRules, rule) @@ -420,6 +425,11 @@ func (h *ChannelHandler) Update(c *gin.Context) { fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) return } + if len(r.Pricing) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING", + fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1))) + return + } rule := accountStatsPricingRuleRequestToService(r) rule.SortOrder = i statsRules = append(statsRules, rule) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 0c1606ea..2324cc70 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1138,6 +1138,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) { changed = append(changed, "registration_email_suffix_whitelist") } + if before.PromoCodeEnabled != after.PromoCodeEnabled { + changed = append(changed, "promo_code_enabled") + } + if before.InvitationCodeEnabled != after.InvitationCodeEnabled { + changed = append(changed, "invitation_code_enabled") + } if before.PasswordResetEnabled != after.PasswordResetEnabled { changed = append(changed, "password_reset_enabled") } @@ -1348,6 +1354,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.CustomMenuItems != after.CustomMenuItems { changed = append(changed, "custom_menu_items") } + if before.CustomEndpoints != after.CustomEndpoints { + changed = append(changed, "custom_endpoints") + } if before.EnableFingerprintUnification != after.EnableFingerprintUnification { changed = append(changed, "enable_fingerprint_unification") } diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go index 1356163d..0eb6bef1 100644 --- a/backend/internal/repository/email_cache.go +++ b/backend/internal/repository/email_cache.go @@ -20,8 +20,9 @@ const ( ) // verifyCodeKey generates the Redis key for email verification code. +// Email is lowercased for case-insensitive consistency. func verifyCodeKey(email string) string { - return verifyCodeKeyPrefix + email + return verifyCodeKeyPrefix + strings.ToLower(email) } // notifyVerifyKey generates the Redis key for notify email verification code. @@ -33,12 +34,12 @@ func notifyVerifyKey(email string) string { // passwordResetKey generates the Redis key for password reset token. func passwordResetKey(email string) string { - return passwordResetKeyPrefix + email + return passwordResetKeyPrefix + strings.ToLower(email) } // passwordResetSentAtKey generates the Redis key for password reset email sent timestamp. func passwordResetSentAtKey(email string) string { - return passwordResetSentAtKeyPrefix + email + return passwordResetSentAtKeyPrefix + strings.ToLower(email) } type emailCache struct { diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 9a75d6be..5e9afcc8 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -283,6 +283,20 @@ func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) return nil } + return filterVerifiedEmails(entries) +} + +// getSiteName reads site name from settings with fallback. +func (s *BalanceNotifyService) getSiteName(ctx context.Context) string { + name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName) + if err != nil || name == "" { + return defaultSiteName + } + return name +} + +// filterVerifiedEmails returns deduplicated, non-disabled, verified emails. +func filterVerifiedEmails(entries []NotifyEmailEntry) []string { var recipients []string seen := make(map[string]bool) for _, entry := range entries { @@ -303,38 +317,10 @@ func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) return recipients } -// getSiteName reads site name from settings with fallback. -func (s *BalanceNotifyService) getSiteName(ctx context.Context) string { - name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName) - if err != nil || name == "" { - return defaultSiteName - } - return name -} - // collectBalanceNotifyRecipients returns verified, non-disabled email recipients. // Only emails with verified=true and disabled=false are included. func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string { - var recipients []string - seen := make(map[string]bool) - - for _, entry := range user.BalanceNotifyExtraEmails { - if entry.Disabled || !entry.Verified { - continue - } - email := strings.TrimSpace(entry.Email) - if email == "" { - continue - } - lower := strings.ToLower(email) - if seen[lower] { - continue - } - seen[lower] = true - recipients = append(recipients, email) - } - - return recipients + return filterVerifiedEmails(user.BalanceNotifyExtraEmails) } // sendEmails sends an email to all recipients with shared timeout and error logging. diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 50d324f2..a94e0dde 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -55,6 +55,7 @@ type VerificationCodeData struct { Code string Attempts int CreatedAt time.Time + ExpiresAt time.Time // absolute expiry; used to preserve remaining TTL when updating attempts } // PasswordResetTokenData represents password reset token data @@ -263,6 +264,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin Code: code, Attempts: 0, CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(verifyCodeTTL), } if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { return fmt.Errorf("save verify code: %w", err) @@ -295,7 +297,11 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error // 验证码不匹配 (constant-time comparison to prevent timing attacks) if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 { data.Attempts++ - if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { + remaining := time.Until(data.ExpiresAt) + if remaining <= 0 { + return ErrInvalidVerifyCode + } + if err := s.cache.SetVerificationCode(ctx, email, data, remaining); err != nil { slog.Error("failed to update verification attempt count", "email", email, "error", err) } if data.Attempts >= maxVerifyCodeAttempts { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index b67a06a7..c65e828a 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1194,12 +1194,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 注意:强制平台模式不走混合调度 if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { - return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + if err != nil { + return nil, err + } + return s.hydrateSelectedAccount(ctx, account) } // antigravity 分组、强制平台模式或无分组使用单平台选择 // 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询 - return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + if err != nil { + return nil, err + } + return s.hydrateSelectedAccount(ctx, account) } // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. @@ -1275,11 +1283,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro localExcluded[account.ID] = struct{}{} // 排除此账号 continue // 重新选择 } - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } // 对于等待计划的情况,也需要先检查会话限制 @@ -1291,26 +1295,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } } - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }) } } @@ -1433,53 +1431,76 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { - if s.isAccountSchedulableForSelection(stickyAccount) && + var stickyCacheMissReason string + + gatePass := s.isAccountSchedulableForSelection(stickyAccount) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && s.isAccountSchedulableForQuota(stickyAccount) && - s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && + s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) - s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 + rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true) + + if rpmPass { // 粘性会话窗口费用+RPM 检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { result.ReleaseFunc() // 释放槽位 + stickyCacheMissReason = "session_limit" // 继续到负载感知选择 } else { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } - return &AccountSelectionResult{ - Account: stickyAccount, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil) } } - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) - if waitingCount < cfg.StickySessionMaxWaiting { - // 会话数量限制检查(等待计划也需要占用会话配额) - if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { - // 会话限制已满,继续到负载感知选择 + if stickyCacheMissReason == "" { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) + if waitingCount < cfg.StickySessionMaxWaiting { + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { + stickyCacheMissReason = "session_limit" + // 会话限制已满,继续到负载感知选择 + } else { + return &AccountSelectionResult{ + Account: stickyAccount, + WaitPlan: &AccountWaitPlan{ + AccountID: stickyAccountID, + MaxConcurrency: stickyAccount.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } else { - return &AccountSelectionResult{ - Account: stickyAccount, - WaitPlan: &AccountWaitPlan{ - AccountID: stickyAccountID, - MaxConcurrency: stickyAccount.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + stickyCacheMissReason = "wait_queue_full" } } // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 + } else if !gatePass { + stickyCacheMissReason = "gate_check" + } else { + stickyCacheMissReason = "rpm_red" + } + + // 记录粘性缓存未命中的结构化日志 + if stickyCacheMissReason != "" { + baseRPM := stickyAccount.GetBaseRPM() + var currentRPM int + if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok { + currentRPM = count + } + logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d", + stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM) } } else { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0", + stickyAccountID, shortSessionHash(sessionHash)) } } } @@ -1544,11 +1565,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } - return &AccountSelectionResult{ - Account: item.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil) } } @@ -1561,15 +1578,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } - return &AccountSelectionResult{ - Account: item.account, - WaitPlan: &AccountWaitPlan{ - AccountID: item.account.ID, - MaxConcurrency: item.account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{ + AccountID: item.account.ID, + MaxConcurrency: item.account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } // 所有路由账号会话限制都已满,继续到 Layer 2 回退 } @@ -1603,11 +1617,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 } else { - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + if s.cache != nil { + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) + } + return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } } @@ -1617,15 +1630,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, account, sessionHash) { // 会话限制已满,继续到 Layer 2 } else { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } } } @@ -1684,7 +1694,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { - if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { + if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil { + return nil, legacyErr + } else if ok { return result, nil } } else { @@ -1723,11 +1735,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) } - return &AccountSelectionResult{ - Account: selected.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil) } } @@ -1750,20 +1758,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, acc, sessionHash) { continue // 会话限制已满,尝试下一个账号 } - return &AccountSelectionResult{ - Account: acc, - WaitPlan: &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{ + AccountID: acc.ID, + MaxConcurrency: acc.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }) } return nil, ErrNoAvailableAccounts } -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) @@ -1778,15 +1783,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) } - return &AccountSelectionResult{ - Account: acc, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, true + selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil) + if err != nil { + return nil, false, err + } + return selection, true, nil } } - return nil, false + return nil, false, nil } func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { @@ -2401,6 +2406,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in return s.accountRepo.GetByID(ctx, accountID) } +func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { + if account == nil || s.schedulerSnapshot == nil { + return account, nil + } + hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID) + if err != nil { + return nil, err + } + if hydrated == nil { + return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID) + } + return hydrated, nil +} + +func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) { + hydrated, err := s.hydrateSelectedAccount(ctx, account) + if err != nil { + return nil, err + } + return &AccountSelectionResult{ + Account: hydrated, + Acquired: acquired, + ReleaseFunc: release, + WaitPlan: waitPlan, + }, nil +} + // filterByMinPriority 过滤出优先级最小的账号集合 func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { if len(accounts) == 0 { @@ -2676,6 +2708,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if groupID != nil && s.groupRepo != nil { + schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) + } + var accounts []Account accountsLoaded := false @@ -2747,6 +2785,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } @@ -2852,6 +2896,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } @@ -2918,6 +2968,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g preferOAuth := nativePlatform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform) + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if groupID != nil && s.groupRepo != nil { + schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) + } + var accounts []Account accountsLoaded := false @@ -2985,6 +3041,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue @@ -3078,6 +3140,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) + // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。 needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { @@ -3090,6 +3153,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue @@ -3257,8 +3326,7 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "excluded"} } if !s.isAccountSchedulableForSelection(acc) { - detail := "generic_unschedulable" - return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} + return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"} } if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { return selectionFailureDiagnosis{ @@ -3282,7 +3350,6 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "eligible"} } -// GetAccessToken 获取账号凭证 func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { if acc == nil { return true @@ -3653,6 +3720,86 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { return result } +// rewriteSystemForNonClaudeCode 将非 Claude Code 客户端的 system prompt 迁移至 messages, +// system 字段仅保留 Claude Code 标识提示词。 +// Anthropic 基于 system 参数内容检测第三方应用,仅前置追加 Claude Code 提示词 +// 无法通过检测,因为后续内容仍为非 Claude Code 格式。 +// 策略:将原始 system prompt 提取并注入为 user/assistant 消息对,system 仅保留 Claude Code 标识。 +func rewriteSystemForNonClaudeCode(body []byte, system any) []byte { + system = normalizeSystemParam(system) + + // 1. 提取原始 system prompt 文本 + var originalSystemText string + switch v := system.(type) { + case string: + originalSystemText = strings.TrimSpace(v) + case []any: + var parts []string + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" { + parts = append(parts, text) + } + } + } + originalSystemText = strings.Join(parts, "\n\n") + } + + // 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致) + // 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。 + // 使用 string 格式会被 Anthropic 检测为第三方应用。 + claudeCodeSystemBlock := []map[string]any{ + { + "type": "text", + "text": claudeCodeSystemPrompt, + "cache_control": map[string]string{"type": "ephemeral"}, + }, + } + out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock) + if !ok { + logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt") + return body + } + + // 3. 将原始 system prompt 作为 user/assistant 消息对注入到 messages 开头 + // 模型仍通过 messages 接收完整指令,保留客户端功能 + ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt) + if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) { + instrMsg, err1 := json.Marshal(map[string]any{ + "role": "user", + "content": []map[string]any{ + {"type": "text", "text": "[System Instructions]\n" + originalSystemText}, + }, + }) + ackMsg, err2 := json.Marshal(map[string]any{ + "role": "assistant", + "content": []map[string]any{ + {"type": "text", "text": "Understood. I will follow these instructions."}, + }, + }) + if err1 != nil || err2 != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to marshal system-to-messages injection") + return out + } + + // 重建 messages 数组:[instruction, ack, ...originalMessages] + items := [][]byte{instrMsg, ackMsg} + messagesResult := gjson.GetBytes(out, "messages") + if messagesResult.IsArray() { + messagesResult.ForEach(func(_, msg gjson.Result) bool { + items = append(items, []byte(msg.Raw)) + return true + }) + } + + if next, setOk := setJSONRawBytes(out, "messages", buildJSONArrayRaw(items)); setOk { + out = next + } + } + + return out +} + type cacheControlPath struct { path string log string @@ -3819,7 +3966,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest. // Always overwrite the cache to prevent stale values from a previous retry with a different account. if account.Platform == PlatformAnthropic && c != nil { - policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account) + policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account, parsed.Model) if policy.blockErr != nil { return nil, policy.blockErr } @@ -3849,19 +3996,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode if shouldMimicClaudeCode { - // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) + // 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 + systemRewritten := false if !strings.Contains(strings.ToLower(reqModel), "haiku") && !systemIncludesClaudeCodePrompt(parsed.System) { - body = injectClaudeCodePrompt(body, parsed.System) + body = rewriteSystemForNonClaudeCode(body, parsed.System) + systemRewritten = true } - normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} + // system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为); + // 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。 + // 两种情况下 enforceCacheControlLimit 都会兜底处理上限。 + normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten} if s.identityService != nil { fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) if err == nil && fp != nil { // metadata 透传开启时跳过 metadata 注入 - _, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx) + _, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx) if !mimicMPT { if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { normalizeOpts.injectMetadata = true @@ -5407,9 +5559,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // OAuth账号:应用统一指纹和metadata重写(受设置开关控制) var fingerprint *Fingerprint - enableFP, enableMPT := true, false + enableFP, enableMPT, enableCCH := true, false, false if s.settingService != nil { - enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx) + enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx) } if account.IsOAuth() && s.identityService != nil { // 1. 获取或创建指纹(包含随机生成的ClientID) @@ -5436,6 +5588,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } + // 同步 billing header cc_version 与实际发送的 User-Agent 版本 + if fingerprint != nil { + body = syncBillingHeaderVersion(body, fingerprint.UserAgent) + } + // CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后) + if enableCCH { + body = signBillingHeaderCCH(body) + } + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err @@ -5476,9 +5637,8 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } // Build effective drop set: merge static defaults with dynamic beta policy filter rules - policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account) + policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID) effectiveDropSet := mergeDropSets(policyFilterSet) - effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode) // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) if tokenType == "oauth" { @@ -5489,11 +5649,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex applyClaudeCodeMimicHeaders(req, reqStream) incomingBeta := getHeaderRaw(req.Header, "anthropic-beta") - // Match real Claude CLI traffic (per mitmproxy reports): - // messages requests typically use only oauth + interleaved-thinking. - // Also drop claude-code beta if a downstream client added it. + // Claude Code OAuth credentials are scoped to Claude Code. + // Non-haiku models MUST include claude-code beta for Anthropic to recognize + // this as a legitimate Claude Code request; without it, the request is + // rejected as third-party ("out of extra usage"). + // Haiku models are exempt from third-party detection and don't need it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet)) + if !strings.Contains(strings.ToLower(modelID), "haiku") { + requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking} + } + setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") @@ -5716,7 +5881,7 @@ type betaPolicyResult struct { } // evaluateBetaPolicy loads settings once and evaluates all rules against the given request. -func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult { +func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account, model string) betaPolicyResult { if s.settingService == nil { return betaPolicyResult{} } @@ -5731,10 +5896,11 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { continue } - switch rule.Action { + effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model) + switch effectiveAction { case BetaPolicyActionBlock: if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) { - msg := rule.ErrorMessage + msg := effectiveErrMsg if msg == "" { msg = "beta feature " + rule.BetaToken + " is not allowed" } @@ -5776,7 +5942,7 @@ const betaPolicyFilterSetKey = "betaPolicyFilterSet" // In the /v1/messages path, Forward() evaluates the policy first and caches the result; // buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this // evaluates on demand (one DB call). -func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} { +func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account, model string) map[string]struct{} { if c != nil { if v, ok := c.Get(betaPolicyFilterSetKey); ok { if fs, ok := v.(map[string]struct{}); ok { @@ -5784,7 +5950,7 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont } } } - return s.evaluateBetaPolicy(ctx, "", account).filterSet + return s.evaluateBetaPolicy(ctx, "", account, model).filterSet } // betaPolicyScopeMatches checks whether a rule's scope matches the current account type. @@ -5803,6 +5969,33 @@ func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool { } } +// matchModelWhitelist checks if a model matches any pattern in the whitelist. +// Reuses matchModelPattern from group.go which supports exact and wildcard prefix matching. +func matchModelWhitelist(model string, whitelist []string) bool { + for _, pattern := range whitelist { + if matchModelPattern(pattern, model) { + return true + } + } + return false +} + +// resolveRuleAction determines the effective action and error message for a rule given the request model. +// When ModelWhitelist is empty, the rule's primary Action/ErrorMessage applies unconditionally. +// When non-empty, Action applies to matching models; FallbackAction/FallbackErrorMessage applies to others. +func resolveRuleAction(rule BetaPolicyRule, model string) (action, errorMessage string) { + if len(rule.ModelWhitelist) == 0 { + return rule.Action, rule.ErrorMessage + } + if matchModelWhitelist(model, rule.ModelWhitelist) { + return rule.Action, rule.ErrorMessage + } + if rule.FallbackAction != "" { + return rule.FallbackAction, rule.FallbackErrorMessage + } + return BetaPolicyActionPass, "" // default fallback: pass (fail-open) +} + // droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. func droppedBetaSet(extra ...string) map[string]struct{} { m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) @@ -5849,7 +6042,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( modelID string, ) ([]string, error) { // 1. 对原始 header 中的 beta token 做 block 检查(快速失败) - policy := s.evaluateBetaPolicy(ctx, betaHeader, account) + policy := s.evaluateBetaPolicy(ctx, betaHeader, account, modelID) if policy.blockErr != nil { return nil, policy.blockErr } @@ -5861,7 +6054,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( // 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token, // 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 → // 如果不做此检查,block 规则会被绕过。 - if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil { + if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account, modelID); blockErr != nil { return nil, blockErr } @@ -5870,7 +6063,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( // checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。 // 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。 -func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError { +func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account, model string) *BetaBlockedError { if s.settingService == nil || len(tokens) == 0 { return nil } @@ -5882,14 +6075,15 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke isBedrock := account.IsBedrock() tokenSet := buildBetaTokenSet(tokens) for _, rule := range settings.Rules { - if rule.Action != BetaPolicyActionBlock { + effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model) + if effectiveAction != BetaPolicyActionBlock { continue } if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { continue } if _, present := tokenSet[rule.BetaToken]; present { - msg := rule.ErrorMessage + msg := effectiveErrMsg if msg == "" { msg = "beta feature " + rule.BetaToken + " is not allowed" } @@ -7146,49 +7340,41 @@ func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool { return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() } -// postUsageBilling 统一处理使用量记录后的扣费逻辑: -// - 订阅/余额扣费 -// - API Key 配额更新 -// - API Key 限速用量更新 -// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率) +// postUsageBilling is the legacy fallback billing path used when the unified +// billing repo is unavailable (nil). Production uses applyUsageBilling → repo.Apply +// for atomic billing. This path only runs in tests or degraded mode. func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { billingCtx, cancel := detachedBillingContext(ctx) defer cancel() cost := p.Cost - // 1. 订阅 / 余额扣费 if p.IsSubscriptionBill { if cost.TotalCost > 0 { if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) } - deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost) } } else { if cost.ActualCost > 0 { if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil { slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) } - deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost) } } - // 2. API Key 配额 if p.shouldDeductAPIKeyQuota() { if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) } } - // 3. API Key 限速用量 if p.shouldUpdateRateLimits() { if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) } } - // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) if p.shouldUpdateAccountQuota() { accountCost := cost.TotalCost * p.AccountRateMultiplier if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { @@ -7196,7 +7382,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill } } - finalizePostUsageBilling(p, deps) + // NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing + // cache updates. The legacy path does DB writes directly; the finalize path + // does cache queue + notifications. Notifications are dispatched separately + // by the caller after recording the usage log. } func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string { @@ -7250,9 +7439,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage cmd.CacheCreationTokens = usageLog.CacheCreationTokens cmd.CacheReadTokens = usageLog.CacheReadTokens cmd.ImageCount = usageLog.ImageCount - if usageLog.MediaType != nil { - cmd.MediaType = *usageLog.MediaType - } if usageLog.ServiceTier != nil { cmd.ServiceTier = *usageLog.ServiceTier } @@ -7315,11 +7501,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog } } - finalizePostUsageBilling(p, deps) + finalizePostUsageBilling(p, deps, result) return true, nil } -func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) { +func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { if p == nil || p.Cost == nil || deps == nil { return } @@ -7338,22 +7524,82 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) { deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) - // Balance low notification — use real-time balance from billing cache (not stale snapshot) - if !p.IsSubscriptionBill && p.Cost.ActualCost > 0 && p.User != nil && deps.balanceNotifyService != nil { - oldBalance := p.User.Balance // fallback to snapshot - if deps.billingCacheService != nil { - if realBalance, err := deps.billingCacheService.GetUserBalance(context.Background(), p.User.ID); err == nil { - oldBalance = realBalance + p.Cost.ActualCost // DB already deducted, reconstruct pre-deduction balance - } + // Notification checks run async — all parameters are already captured, + // no dependency on the request context or upstream connection. + go notifyBalanceLow(p, deps, result) + go notifyAccountQuota(p, deps, result) +} + +// notifyBalanceLow sends balance low notification after deduction. +// When result.NewBalance is available (from DB transaction RETURNING), it is used directly +// to reconstruct oldBalance, avoiding stale Redis reads and concurrent-deduction races. +func notifyBalanceLow(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { + defer func() { + if r := recover(); r != nil { + slog.Error("panic in notifyBalanceLow", "recover", r) } - deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost) + }() + if p.IsSubscriptionBill || p.Cost.ActualCost <= 0 || p.User == nil || deps.balanceNotifyService == nil { + slog.Debug("notifyBalanceLow: skipped", + "is_subscription", p.IsSubscriptionBill, + "actual_cost", p.Cost.ActualCost, + "user_nil", p.User == nil, + "service_nil", deps.balanceNotifyService == nil, + ) + return } - // Account quota notification (use same cost formula as postUsageBilling) - if p.Cost.TotalCost > 0 && p.Account != nil && p.Account.IsAPIKeyOrBedrock() && deps.balanceNotifyService != nil { - accountCost := p.Cost.TotalCost * p.AccountRateMultiplier - deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost) + oldBalance := resolveOldBalance(p, result) + slog.Debug("notifyBalanceLow: calling CheckBalanceAfterDeduction", + "user_id", p.User.ID, + "old_balance", oldBalance, + "cost", p.Cost.ActualCost, + "notify_enabled", p.User.BalanceNotifyEnabled, + "threshold", p.User.BalanceNotifyThreshold, + "result_has_new_balance", result != nil && result.NewBalance != nil, + ) + deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost) +} + +// resolveOldBalance returns the pre-deduction balance. +// Prefers the DB transaction result (newBalance + cost) over snapshot. +func resolveOldBalance(p *postUsageBillingParams, result *UsageBillingApplyResult) float64 { + if result != nil && result.NewBalance != nil { + return *result.NewBalance + p.Cost.ActualCost } + // Legacy fallback: snapshot balance from request context + return p.User.Balance +} + +// notifyAccountQuota sends account quota threshold notification after increment. +// When result.QuotaState is available (from DB transaction RETURNING), it is passed directly +// to avoid a separate DB read that may see stale or concurrently-modified data. +func notifyAccountQuota(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { + defer func() { + if r := recover(); r != nil { + slog.Error("panic in notifyAccountQuota", "recover", r) + } + }() + if p.Cost.TotalCost <= 0 || p.Account == nil || !p.Account.IsAPIKeyOrBedrock() || deps.balanceNotifyService == nil { + slog.Debug("notifyAccountQuota: skipped", + "total_cost", p.Cost.TotalCost, + "account_nil", p.Account == nil, + "is_apikey_or_bedrock", p.Account != nil && p.Account.IsAPIKeyOrBedrock(), + "service_nil", deps.balanceNotifyService == nil, + ) + return + } + accountCost := p.Cost.TotalCost * p.AccountRateMultiplier + var quotaState *AccountQuotaState + if result != nil { + quotaState = result.QuotaState + } + slog.Debug("notifyAccountQuota: calling CheckAccountQuotaAfterIncrement", + "account_id", p.Account.ID, + "account_cost", accountCost, + "has_quota_state", quotaState != nil, + ) + deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost, quotaState) } func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) { @@ -7422,11 +7668,11 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage // recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。 type recordUsageOpts struct { - // ParsedRequest(可选,仅 Claude 路径传入) + // Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入) ParsedRequest *ParsedRequest // EnableClaudePath 启用 Claude 路径特有逻辑: - // - MediaType 字段写入使用日志 + // - Claude Max 缓存计费策略 EnableClaudePath bool // 长上下文计费(仅 Gemini 路径需要) @@ -7451,7 +7697,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu APIKeyService: input.APIKeyService, ChannelUsageFields: input.ChannelUsageFields, }, &recordUsageOpts{ - ParsedRequest: input.ParsedRequest, EnableClaudePath: true, }) } @@ -7517,6 +7762,7 @@ type recordUsageCoreInput struct { // recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 // opts 中的字段控制两者之间的差异行为: +// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略 // - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { result := input.Result @@ -7583,13 +7829,10 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage // 计算账号统计定价费用(使用最终上游模型匹配自定义规则) if apiKey.GroupID != nil { - upstreamModel := result.UpstreamModel - if upstreamModel == "" { - upstreamModel = result.Model - } - usageLog.AccountStatsCost = resolveAccountStatsCost( - ctx, s.channelService, s.billingService, - account.ID, *apiKey.GroupID, upstreamModel, + applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService, + account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model, + // Anthropic's input_tokens excludes cache_read and cache_creation (billed separately); + // OpenAI gateway uses actualInputTokens which also excludes cache_read for the same reason. UsageTokens{ InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -7597,7 +7840,6 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage CacheReadTokens: result.Usage.CacheReadInputTokens, ImageOutputTokens: result.Usage.ImageOutputTokens, }, - 1, // requestCount cost.TotalCost, ) } @@ -7796,13 +8038,12 @@ func (s *GatewayService) buildRecordUsageLog( RateMultiplier: multiplier, AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, - BillingMode: resolveBillingMode(opts, result, cost), + BillingMode: resolveBillingMode(result, cost), Stream: result.Stream, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: optionalTrimmedStringPtr(result.ImageSize), - MediaType: resolveMediaType(opts, result), CacheTTLOverridden: cacheTTLOverridden, ChannelID: optionalInt64Ptr(input.ChannelID), ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), @@ -7826,7 +8067,7 @@ func (s *GatewayService) buildRecordUsageLog( } // resolveBillingMode 根据计费结果和请求类型确定计费模式。 -func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string { +func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string { var mode string switch { case cost != nil && cost.BillingMode != "": @@ -7839,10 +8080,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost return &mode } -func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string { - return nil -} - func optionalSubscriptionID(subscription *UserSubscription) *int64 { if subscription != nil { return &subscription.ID @@ -8349,9 +8586,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:应用统一指纹和重写 userID(受设置开关控制) // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 - ctEnableFP, ctEnableMPT := true, false + ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false if s.settingService != nil { - ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx) + ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx) } var ctFingerprint *Fingerprint if account.IsOAuth() && s.identityService != nil { @@ -8369,6 +8606,14 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } + // 同步 billing header cc_version 与实际发送的 User-Agent 版本 + if ctFingerprint != nil && ctEnableFP { + body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent) + } + if ctEnableCCH { + body = signBillingHeaderCCH(body) + } + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err @@ -8409,7 +8654,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } // Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules - ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account)) + ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID)) // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { diff --git a/backend/internal/service/payment_config_plans_validation_test.go b/backend/internal/service/payment_config_plans_validation_test.go index 9a2d8716..efdbdb10 100644 --- a/backend/internal/service/payment_config_plans_validation_test.go +++ b/backend/internal/service/payment_config_plans_validation_test.go @@ -128,3 +128,66 @@ func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) { err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil}) require.NoError(t, err) } + +// --- validatePlanPatch: other fields --- + +func ptrStr(s string) *string { return &s } +func ptrInt(i int) *int { return &i } +func ptrInt64(i int64) *int64 { return &i } +func ptrFloat(f float64) *float64 { return &f } + +func TestValidatePlanPatch_EmptyName(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("")}) + require.Error(t, err) + require.Contains(t, err.Error(), "plan name") +} + +func TestValidatePlanPatch_ValidName(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("Basic")}) + require.NoError(t, err) +} + +func TestValidatePlanPatch_ZeroGroupID(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{GroupID: ptrInt64(0)}) + require.Error(t, err) + require.Contains(t, err.Error(), "group") +} + +func TestValidatePlanPatch_NegativePrice(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(-1)}) + require.Error(t, err) + require.Contains(t, err.Error(), "price") +} + +func TestValidatePlanPatch_ZeroPrice(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(0)}) + require.Error(t, err) + require.Contains(t, err.Error(), "price") +} + +func TestValidatePlanPatch_ValidPrice(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(9.99)}) + require.NoError(t, err) +} + +func TestValidatePlanPatch_ZeroValidityDays(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{ValidityDays: ptrInt(0)}) + require.Error(t, err) + require.Contains(t, err.Error(), "validity days") +} + +func TestValidatePlanPatch_EmptyValidityUnit(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("")}) + require.Error(t, err) + require.Contains(t, err.Error(), "validity unit") +} + +func TestValidatePlanPatch_ValidValidityUnit(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("days")}) + require.NoError(t, err) +} + +func TestValidatePlanPatch_AllNil(t *testing.T) { + err := validatePlanPatch(UpdatePlanRequest{}) + require.NoError(t, err) +} diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 7602d162..a7724a5a 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -330,6 +330,7 @@ func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code str Code: code, Attempts: 0, CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(verifyCodeTTL), } if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil { return fmt.Errorf("save verify code: %w", err) @@ -370,7 +371,11 @@ func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string) } if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 { data.Attempts++ - if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil { + remaining := time.Until(data.ExpiresAt) + if remaining <= 0 { + return ErrInvalidVerifyCode + } + if err := cache.SetNotifyVerifyCode(ctx, email, data, remaining); err != nil { slog.Error("failed to update notify verify code attempts", "email", email, "error", err) } if data.Attempts >= maxVerifyCodeAttempts { @@ -418,11 +423,17 @@ func (s *UserService) RemoveNotifyEmail(ctx context.Context, userID int64, email } filtered := make([]NotifyEmailEntry, 0, len(user.BalanceNotifyExtraEmails)) + found := false for _, e := range user.BalanceNotifyExtraEmails { - if !strings.EqualFold(e.Email, email) { + if strings.EqualFold(e.Email, email) { + found = true + } else { filtered = append(filtered, e) } } + if !found { + return infraerrors.BadRequest("EMAIL_NOT_FOUND", "notification email not found") + } user.BalanceNotifyExtraEmails = filtered return s.userRepo.Update(ctx, user) } diff --git a/frontend/src/components/account/QuotaDimensionRow.vue b/frontend/src/components/account/QuotaDimensionRow.vue index 1406faa9..e7fe2d0b 100644 --- a/frontend/src/components/account/QuotaDimensionRow.vue +++ b/frontend/src/components/account/QuotaDimensionRow.vue @@ -1,6 +1,7 @@ diff --git a/frontend/src/components/account/QuotaLimitCard.vue b/frontend/src/components/account/QuotaLimitCard.vue index 77e437a8..68a68f29 100644 --- a/frontend/src/components/account/QuotaLimitCard.vue +++ b/frontend/src/components/account/QuotaLimitCard.vue @@ -2,6 +2,7 @@ import { ref, watch, computed } from 'vue' import { useI18n } from 'vue-i18n' import QuotaDimensionRow from './QuotaDimensionRow.vue' +import type { QuotaThresholdType, QuotaResetMode } from '@/constants/account' const { t } = useI18n() @@ -9,22 +10,22 @@ const props = withDefaults(defineProps<{ totalLimit: number | null dailyLimit: number | null weeklyLimit: number | null - dailyResetMode: 'rolling' | 'fixed' | null + dailyResetMode: QuotaResetMode | null dailyResetHour: number | null - weeklyResetMode: 'rolling' | 'fixed' | null + weeklyResetMode: QuotaResetMode | null weeklyResetDay: number | null weeklyResetHour: number | null resetTimezone: string | null quotaNotifyGlobalEnabled?: boolean quotaNotifyDailyEnabled?: boolean | null quotaNotifyDailyThreshold?: number | null - quotaNotifyDailyThresholdType?: string | null + quotaNotifyDailyThresholdType?: QuotaThresholdType | null quotaNotifyWeeklyEnabled?: boolean | null quotaNotifyWeeklyThreshold?: number | null - quotaNotifyWeeklyThresholdType?: string | null + quotaNotifyWeeklyThresholdType?: QuotaThresholdType | null quotaNotifyTotalEnabled?: boolean | null quotaNotifyTotalThreshold?: number | null - quotaNotifyTotalThresholdType?: string | null + quotaNotifyTotalThresholdType?: QuotaThresholdType | null }>(), { quotaNotifyGlobalEnabled: false, quotaNotifyDailyEnabled: null, @@ -42,21 +43,21 @@ const emit = defineEmits<{ 'update:totalLimit': [value: number | null] 'update:dailyLimit': [value: number | null] 'update:weeklyLimit': [value: number | null] - 'update:dailyResetMode': [value: 'rolling' | 'fixed' | null] + 'update:dailyResetMode': [value: QuotaResetMode | null] 'update:dailyResetHour': [value: number | null] - 'update:weeklyResetMode': [value: 'rolling' | 'fixed' | null] + 'update:weeklyResetMode': [value: QuotaResetMode | null] 'update:weeklyResetDay': [value: number | null] 'update:weeklyResetHour': [value: number | null] 'update:resetTimezone': [value: string | null] 'update:quotaNotifyDailyEnabled': [value: boolean | null] 'update:quotaNotifyDailyThreshold': [value: number | null] - 'update:quotaNotifyDailyThresholdType': [value: string | null] + 'update:quotaNotifyDailyThresholdType': [value: QuotaThresholdType | null] 'update:quotaNotifyWeeklyEnabled': [value: boolean | null] 'update:quotaNotifyWeeklyThreshold': [value: number | null] - 'update:quotaNotifyWeeklyThresholdType': [value: string | null] + 'update:quotaNotifyWeeklyThresholdType': [value: QuotaThresholdType | null] 'update:quotaNotifyTotalEnabled': [value: boolean | null] 'update:quotaNotifyTotalThreshold': [value: number | null] - 'update:quotaNotifyTotalThresholdType': [value: string | null] + 'update:quotaNotifyTotalThresholdType': [value: QuotaThresholdType | null] }>() const enabled = computed(() => @@ -89,11 +90,6 @@ watch(localEnabled, (val) => { } }) -// Whether any fixed mode is active (to show timezone selector) -const hasFixedMode = computed(() => - props.dailyResetMode === 'fixed' || props.weeklyResetMode === 'fixed' -) - // Common timezone options const timezoneOptions = [ 'UTC', 'Asia/Shanghai', 'Asia/Tokyo', 'Asia/Seoul', 'Asia/Singapore', 'Asia/Kolkata', @@ -102,18 +98,6 @@ const timezoneOptions = [ 'America/Sao_Paulo', 'Australia/Sydney', 'Pacific/Auckland', ] -// Compute GMT offset label (e.g. "GMT+8", "GMT-5") for a given IANA timezone. -function getTimezoneOffsetLabel(tz: string): string { - try { - const dtf = new Intl.DateTimeFormat('en-US', { timeZone: tz, timeZoneName: 'shortOffset' }) - const parts = dtf.formatToParts(new Date()) - const tzPart = parts.find(p => p.type === 'timeZoneName') - return tzPart ? (tzPart.value === 'GMT' ? 'GMT+0' : tzPart.value) : '' - } catch { - return '' - } -} - // Hours for dropdown (0-23) const hourOptions = Array.from({ length: 24 }, (_, i) => i) @@ -197,6 +181,7 @@ const dailyFixedHint = computed(() => :hint-fixed="dailyFixedHint" :hour-options="hourOptions" :day-options="dayOptions" + :timezone-options="timezoneOptions" @update:limit="emit('update:dailyLimit', $event)" @update:notify-enabled="emit('update:quotaNotifyDailyEnabled', $event)" @update:notify-threshold="emit('update:quotaNotifyDailyThreshold', $event)" @@ -223,6 +208,7 @@ const dailyFixedHint = computed(() => :hint-fixed="weeklyFixedHint" :hour-options="hourOptions" :day-options="dayOptions" + :timezone-options="timezoneOptions" @update:limit="emit('update:weeklyLimit', $event)" @update:notify-enabled="emit('update:quotaNotifyWeeklyEnabled', $event)" @update:notify-threshold="emit('update:quotaNotifyWeeklyThreshold', $event)" @@ -233,14 +219,6 @@ const dailyFixedHint = computed(() => @update:reset-timezone="emit('update:resetTimezone', $event)" /> - -
- - -
- -import { QUOTA_THRESHOLD_TYPE_FIXED, QUOTA_THRESHOLD_TYPE_PERCENTAGE } from '@/constants/account' +import { QUOTA_THRESHOLD_TYPE_FIXED, QUOTA_THRESHOLD_TYPE_PERCENTAGE, type QuotaThresholdType } from '@/constants/account' defineProps<{ enabled: boolean | null threshold: number | null - thresholdType: string | null // "fixed" (default) or "percentage" + thresholdType: QuotaThresholdType | null }>() const emit = defineEmits<{ 'update:enabled': [value: boolean | null] 'update:threshold': [value: number | null] - 'update:thresholdType': [value: string | null] + 'update:thresholdType': [value: QuotaThresholdType | null] }>() @@ -43,7 +43,7 @@ const emit = defineEmits<{ /> - {{ getGroupNameById(gid) }} + {{ getGroupNameById(gid) }}

@@ -481,7 +481,7 @@ :key="accountId" class="inline-flex items-center gap-1 rounded-md border border-primary-300 bg-primary-50 px-2 py-0.5 text-xs dark:border-primary-700 dark:bg-primary-900/20" > - {{ getRuleAccountLabel(accountId) }} + {{ getRuleAccountLabel(accountId) }} @@ -595,7 +595,7 @@ import type { PricingFormEntry } from '@/components/admin/channel/types' import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types' import type { AdminGroup, GroupPlatform } from '@/types' import type { Column } from '@/components/common/types' -import { platformTextClass } from '@/utils/platformColors' +import { platformTextClass, platformBadgeLightClass } from '@/utils/platformColors' import AppLayout from '@/components/layout/AppLayout.vue' import TablePageLayout from '@/components/layout/TablePageLayout.vue' import DataTable from '@/components/common/DataTable.vue' @@ -720,26 +720,6 @@ let abortController: AbortController | null = null // ── Platform config ── const platformOrder: GroupPlatform[] = ['anthropic', 'openai', 'gemini', 'antigravity'] -function getPlatformTextColor(platform: string): string { - switch (platform) { - case 'anthropic': return 'text-orange-600 dark:text-orange-400' - case 'openai': return 'text-emerald-600 dark:text-emerald-400' - case 'gemini': return 'text-blue-600 dark:text-blue-400' - case 'antigravity': return 'text-purple-600 dark:text-purple-400' - default: return 'text-gray-600 dark:text-gray-400' - } -} - -function getRateBadgeClass(platform: string): string { - switch (platform) { - case 'anthropic': return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400' - case 'openai': return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400' - case 'gemini': return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400' - case 'antigravity': return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400' - default: return 'bg-gray-100 text-gray-700 dark:bg-gray-900/30 dark:text-gray-400' - } -} - // ── Helpers ── function formatDate(value: string): string { if (!value) return '-' From 9c09bd19b479c562098a58a821f69c42d7214204 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 00:42:40 +0800 Subject: [PATCH 66/88] fix: websearch features_config cleanup and pricing rules validation - Fix web_search_emulation toggle: explicitly write false for disabled platforms instead of leaving stale true from cloned features_config - Extract validatePricingEntries from validateChannelConfig for reuse - Validate account_stats_pricing_rules[].pricing in both Create and Update paths (negative prices, bad intervals, missing per_request price) --- backend/internal/service/channel_service.go | 22 ++++++++++++++++++--- frontend/src/views/admin/ChannelsView.vue | 8 ++++++-- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index d0698f0f..aa5e2ceb 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -566,15 +566,21 @@ func ReplaceModelInBody(body []byte, newModel string) []byte { // validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。 // Create 和 Update 共用此函数,避免重复。 func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error { + if err := validatePricingEntries(pricing); err != nil { + return err + } + return validateNoConflictingMappings(mapping) +} + +// validatePricingEntries 校验定价条目(冲突检测 + 区间校验 + 计费模式校验), +// 同时用于主渠道定价和 account_stats_pricing_rules 的内部定价。 +func validatePricingEntries(pricing []ChannelModelPricing) error { if err := validateNoConflictingModels(pricing); err != nil { return err } if err := validatePricingIntervals(pricing); err != nil { return err } - if err := validateNoConflictingMappings(mapping); err != nil { - return err - } return validatePricingBillingMode(pricing) } @@ -684,6 +690,11 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { return nil, err } + for i, rule := range channel.AccountStatsPricingRules { + if err := validatePricingEntries(rule.Pricing); err != nil { + return nil, fmt.Errorf("account stats pricing rule #%d: %w", i+1, err) + } + } if err := s.repo.Create(ctx, channel); err != nil { return nil, fmt.Errorf("create channel: %w", err) @@ -712,6 +723,11 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { return nil, err } + for i, rule := range channel.AccountStatsPricingRules { + if err := validatePricingEntries(rule.Pricing); err != nil { + return nil, fmt.Errorf("account stats pricing rule #%d: %w", i+1, err) + } + } oldGroupIDs := s.getOldGroupIDs(ctx, id) diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 52d57d74..0b37a20d 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -1032,15 +1032,19 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[ } // Collect web_search_emulation (only anthropic platform supports it) + // Always write the key so that disabling in the UI correctly sets platform to false, + // rather than leaving a stale true value from the cloned features_config. const wsEmulation: Record = {} for (const section of form.platforms) { if (!section.enabled) continue - if (section.web_search_emulation && section.platform === 'anthropic') { - wsEmulation[section.platform] = true + if (section.platform === 'anthropic') { + wsEmulation[section.platform] = !!section.web_search_emulation } } if (Object.keys(wsEmulation).length > 0) { featuresConfig.web_search_emulation = wsEmulation + } else { + delete featuresConfig.web_search_emulation } return { group_ids, model_pricing, model_mapping, features_config: featuresConfig } From 0a4ece5f5be30db727bb9694588bfd0d0d99a26c Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 01:10:46 +0800 Subject: [PATCH 67/88] =?UTF-8?q?fix:=20audit=20round-3=20=E2=80=94=20prox?= =?UTF-8?q?y=20safety,=20intervals=20persistence,=20SMTP=20timeout,=20sort?= =?UTF-8?q?=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Skip websearch provider when ProxyID is set but proxy not found (prevent silent direct connection bypass) - Fix sortByStableRandomWeight: pair factors with items so sort.Slice swap keeps weights aligned - Allow empty platform in account_stats_pricing_rules (wildcard matching), only force anthropic default for main model_pricing - Add channel_account_stats_pricing_intervals table and repo layer support for interval-based pricing in account stats rules - calculateTokenStatsCost now uses interval pricing when available - Replace smtp.SendMail/tls.Dial with net.Dialer timeout (10s dial, 20s IO) to prevent goroutine leak on SMTP hang - Fix gofmt formatting issues - Web Search label: black text with red warning hint --- .../internal/handler/admin/channel_handler.go | 14 +++- backend/internal/pkg/websearch/manager.go | 17 +++-- .../channel_repo_account_stats_pricing.go | 74 +++++++++++++++++++ backend/internal/repository/email_cache.go | 10 +-- backend/internal/server/http.go | 6 ++ .../internal/service/account_stats_pricing.go | 31 ++++++-- .../service/balance_notify_service.go | 1 - backend/internal/service/email_service.go | 49 +++++++++++- .../internal/service/notify_email_entry.go | 1 - ...06_add_account_stats_pricing_intervals.sql | 19 +++++ frontend/src/views/admin/ChannelsView.vue | 4 +- 11 files changed, 199 insertions(+), 27 deletions(-) create mode 100644 backend/migrations/106_add_account_stats_pricing_intervals.sql diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 1a328551..88d27c47 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -249,9 +249,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe billingMode = service.BillingModeToken } platform := r.Platform - if platform == "" { - platform = service.PlatformAnthropic - } intervals := make([]service.PricingInterval, 0, len(r.Intervals)) for _, iv := range r.Intervals { intervals = append(intervals, service.PricingInterval{ @@ -349,6 +346,12 @@ func (h *ChannelHandler) Create(c *gin.Context) { } pricing := pricingRequestToService(req.ModelPricing) + // Main model_pricing requires a platform; default to anthropic for backward compatibility. + for i := range pricing { + if pricing[i].Platform == "" { + pricing[i].Platform = service.PlatformAnthropic + } + } var statsRules []service.AccountStatsPricingRule for i, r := range req.AccountStatsPricingRules { @@ -415,6 +418,11 @@ func (h *ChannelHandler) Update(c *gin.Context) { } if req.ModelPricing != nil { pricing := pricingRequestToService(*req.ModelPricing) + for i := range pricing { + if pricing[i].Platform == "" { + pricing[i].Platform = service.PlatformAnthropic + } + } input.ModelPricing = &pricing } if req.AccountStatsPricingRules != nil { diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go index 27592459..61faa616 100644 --- a/backend/internal/pkg/websearch/manager.go +++ b/backend/internal/pkg/websearch/manager.go @@ -200,13 +200,20 @@ func sortByStableRandomWeight(items []weighted) { if len(items) <= 1 { return } - factors := make([]float64, len(items)) - for i, item := range items { - factors[i] = float64(item.weight) * (0.5 + rand.Float64()) + type entry struct { + item weighted + factor float64 } - sort.Slice(items, func(i, j int) bool { - return factors[i] > factors[j] + entries := make([]entry, len(items)) + for i, item := range items { + entries[i] = entry{item: item, factor: float64(item.weight) * (0.5 + rand.Float64())} + } + sort.Slice(entries, func(i, j int) bool { + return entries[i].factor > entries[j].factor }) + for i, e := range entries { + items[i] = e.item + } } func mergeWeightedResults(withQuota, withoutQuota []weighted, capacity int) []ProviderConfig { diff --git a/backend/internal/repository/channel_repo_account_stats_pricing.go b/backend/internal/repository/channel_repo_account_stats_pricing.go index ef8f5177..9e00fed8 100644 --- a/backend/internal/repository/channel_repo_account_stats_pricing.go +++ b/backend/internal/repository/channel_repo_account_stats_pricing.go @@ -96,6 +96,27 @@ func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Contex if err := rows.Err(); err != nil { return nil, fmt.Errorf("iterate account stats model pricing: %w", err) } + + // Load intervals for all pricing entries. + var allPricingIDs []int64 + for _, pricings := range pricingMap { + for _, p := range pricings { + allPricingIDs = append(allPricingIDs, p.ID) + } + } + if len(allPricingIDs) > 0 { + intervalsMap, err := r.batchLoadAccountStatsIntervals(ctx, allPricingIDs) + if err != nil { + return nil, err + } + for ruleID, pricings := range pricingMap { + for i := range pricings { + pricings[i].Intervals = intervalsMap[pricings[i].ID] + } + pricingMap[ruleID] = pricings + } + } + return pricingMap, nil } @@ -166,5 +187,58 @@ func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID in if err != nil { return fmt.Errorf("insert account stats model pricing: %w", err) } + // Persist intervals (mirrors channel_pricing_intervals logic). + for i := range pricing.Intervals { + iv := &pricing.Intervals[i] + iv.PricingID = pricing.ID + if err := createAccountStatsIntervalTx(ctx, tx, iv); err != nil { + return err + } + } return nil } + +// createAccountStatsIntervalTx inserts a single interval for an account stats pricing entry. +func createAccountStatsIntervalTx(ctx context.Context, tx *sql.Tx, iv *service.PricingInterval) error { + return tx.QueryRowContext(ctx, + `INSERT INTO channel_account_stats_pricing_intervals + (pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`, + iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel, + iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice, + iv.PerRequestPrice, iv.SortOrder, + ).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt) +} + +// batchLoadAccountStatsIntervals loads intervals for account stats pricing entries. +func (r *channelRepository) batchLoadAccountStatsIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) { + if len(pricingIDs) == 0 { + return nil, nil + } + rows, err := r.db.QueryContext(ctx, + `SELECT id, pricing_id, min_tokens, max_tokens, tier_label, + input_price, output_price, cache_write_price, cache_read_price, + per_request_price, sort_order, created_at, updated_at + FROM channel_account_stats_pricing_intervals + WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`, + pq.Array(pricingIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load account stats pricing intervals: %w", err) + } + defer func() { _ = rows.Close() }() + + result := make(map[int64][]service.PricingInterval) + for rows.Next() { + var iv service.PricingInterval + if err := rows.Scan( + &iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel, + &iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice, + &iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan account stats pricing interval: %w", err) + } + result[iv.PricingID] = append(result[iv.PricingID], iv) + } + return result, rows.Err() +} diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go index 0eb6bef1..96a23a8e 100644 --- a/backend/internal/repository/email_cache.go +++ b/backend/internal/repository/email_cache.go @@ -12,11 +12,11 @@ import ( ) const ( - verifyCodeKeyPrefix = "verify_code:" - notifyVerifyKeyPrefix = "notify_verify:" - passwordResetKeyPrefix = "password_reset:" - passwordResetSentAtKeyPrefix = "password_reset_sent:" - notifyCodeUserRateKeyPrefix = "notify_code_user_rate:" + verifyCodeKeyPrefix = "verify_code:" + notifyVerifyKeyPrefix = "notify_verify:" + passwordResetKeyPrefix = "password_reset:" + passwordResetSentAtKeyPrefix = "password_reset_sent:" + notifyCodeUserRateKeyPrefix = "notify_code_user_rate:" ) // verifyCodeKey generates the Redis key for email verification code. diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index 5165b059..d203bab2 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -4,6 +4,7 @@ package server import ( "context" "log" + "log/slog" "net/http" "time" @@ -82,6 +83,11 @@ func ProvideRouter( pc.ProxyID = *p.ProxyID if u, ok := proxyURLs[*p.ProxyID]; ok { pc.ProxyURL = u + } else { + // Proxy configured but not found — skip this provider to prevent direct connection. + slog.Warn("websearch: proxy not found for provider, skipping", + "provider", p.Type, "proxy_id", *p.ProxyID) + continue } } configs = append(configs, pc) diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index 8251dede..61c318d9 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -195,18 +195,33 @@ func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int } // calculateTokenStatsCost Token 计费。 +// If the pricing has intervals, find the matching interval by total token count +// and use its prices instead of the flat pricing fields. func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 { - deref := func(p *float64) float64 { - if p == nil { + p := pricing + if len(pricing.Intervals) > 0 { + totalTokens := tokens.InputTokens + tokens.OutputTokens + tokens.CacheCreationTokens + tokens.CacheReadTokens + if iv := FindMatchingInterval(pricing.Intervals, totalTokens); iv != nil { + p = &ChannelModelPricing{ + InputPrice: iv.InputPrice, + OutputPrice: iv.OutputPrice, + CacheWritePrice: iv.CacheWritePrice, + CacheReadPrice: iv.CacheReadPrice, + PerRequestPrice: iv.PerRequestPrice, + } + } + } + deref := func(ptr *float64) float64 { + if ptr == nil { return 0 } - return *p + return *ptr } - cost := float64(tokens.InputTokens)*deref(pricing.InputPrice) + - float64(tokens.OutputTokens)*deref(pricing.OutputPrice) + - float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) + - float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) + - float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice) + cost := float64(tokens.InputTokens)*deref(p.InputPrice) + + float64(tokens.OutputTokens)*deref(p.OutputPrice) + + float64(tokens.CacheCreationTokens)*deref(p.CacheWritePrice) + + float64(tokens.CacheReadTokens)*deref(p.CacheReadPrice) + + float64(tokens.ImageOutputTokens)*deref(p.ImageOutputPrice) if cost <= 0 { return nil } diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 5e9afcc8..5b7e413a 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -477,4 +477,3 @@ func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountID int64, account } return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay) } - diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index a94e0dde..425887cd 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -9,6 +9,7 @@ import ( "fmt" "log/slog" "math/big" + "net" "net/smtp" "net/url" "strconv" @@ -152,6 +153,9 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) return s.SendEmailWithConfig(config, to, subject, body) } +const smtpDialTimeout = 10 * time.Second +const smtpIOTimeout = 20 * time.Second + // SendEmailWithConfig 使用指定配置发送邮件 func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error { // Sanitize all SMTP header fields to prevent header injection (CR/LF removal). @@ -173,7 +177,46 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host) } - return smtp.SendMail(addr, auth, config.From, []string{to}, []byte(msg)) + return s.sendMailPlain(addr, auth, config.From, to, []byte(msg), config.Host) +} + +// sendMailPlain sends mail without TLS using a dialer with timeout. +func (s *EmailService) sendMailPlain(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error { + dialer := &net.Dialer{Timeout: smtpDialTimeout} + conn, err := dialer.Dial("tcp", addr) + if err != nil { + return fmt.Errorf("smtp dial: %w", err) + } + _ = conn.SetDeadline(time.Now().Add(smtpIOTimeout)) + defer func() { _ = conn.Close() }() + + client, err := smtp.NewClient(conn, host) + if err != nil { + return fmt.Errorf("new smtp client: %w", err) + } + defer func() { _ = client.Close() }() + + if err = client.Auth(auth); err != nil { + return fmt.Errorf("smtp auth: %w", err) + } + if err = client.Mail(from); err != nil { + return fmt.Errorf("smtp mail: %w", err) + } + if err = client.Rcpt(to); err != nil { + return fmt.Errorf("smtp rcpt: %w", err) + } + w, err := client.Data() + if err != nil { + return fmt.Errorf("smtp data: %w", err) + } + if _, err = w.Write(msg); err != nil { + return fmt.Errorf("write msg: %w", err) + } + if err = w.Close(); err != nil { + return fmt.Errorf("close writer: %w", err) + } + _ = client.Quit() + return nil } // sendMailTLS 使用TLS发送邮件 @@ -184,10 +227,12 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string, MinVersion: tls.VersionTLS12, } - conn, err := tls.Dial("tcp", addr, tlsConfig) + dialer := &net.Dialer{Timeout: smtpDialTimeout} + conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig) if err != nil { return fmt.Errorf("tls dial: %w", err) } + _ = conn.SetDeadline(time.Now().Add(smtpIOTimeout)) defer func() { _ = conn.Close() }() client, err := smtp.NewClient(conn, host) diff --git a/backend/internal/service/notify_email_entry.go b/backend/internal/service/notify_email_entry.go index d181200b..625185b2 100644 --- a/backend/internal/service/notify_email_entry.go +++ b/backend/internal/service/notify_email_entry.go @@ -79,4 +79,3 @@ func MarshalNotifyEmails(entries []NotifyEmailEntry) string { } return string(data) } - diff --git a/backend/migrations/106_add_account_stats_pricing_intervals.sql b/backend/migrations/106_add_account_stats_pricing_intervals.sql new file mode 100644 index 00000000..5ae10655 --- /dev/null +++ b/backend/migrations/106_add_account_stats_pricing_intervals.sql @@ -0,0 +1,19 @@ +-- Add intervals table for account stats pricing rules (mirrors channel_pricing_intervals). +CREATE TABLE IF NOT EXISTS channel_account_stats_pricing_intervals ( + id BIGSERIAL PRIMARY KEY, + pricing_id BIGINT NOT NULL REFERENCES channel_account_stats_model_pricing(id) ON DELETE CASCADE, + min_tokens INT NOT NULL DEFAULT 0, + max_tokens INT, + tier_label VARCHAR(50), + input_price NUMERIC(20,12), + output_price NUMERIC(20,12), + cache_write_price NUMERIC(20,12), + cache_read_price NUMERIC(20,12), + per_request_price NUMERIC(20,12), + sort_order INT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_account_stats_pricing_intervals_pricing_id + ON channel_account_stats_pricing_intervals (pricing_id); diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 0b37a20d..e4452b98 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -328,10 +328,10 @@

-
From b402c367d331c24b1507bdbc0596c06c933b1d55 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 01:38:42 +0800 Subject: [PATCH 68/88] fix: add opportunistic STARTTLS to sendMailPlain for 587 port compatibility smtp.SendMail automatically upgrades to STARTTLS when the server supports it. Our replacement sendMailPlain skipped this, causing credentials to be sent in plaintext on port 587. Add STARTTLS negotiation before Auth to restore the original security behavior. --- backend/internal/service/email_service.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 425887cd..9cfd3bbd 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -196,6 +196,14 @@ func (s *EmailService) sendMailPlain(addr string, auth smtp.Auth, from, to strin } defer func() { _ = client.Close() }() + // Opportunistic STARTTLS: upgrade to encrypted connection if the server supports it. + // This mirrors the behavior of smtp.SendMail which we replaced for timeout support. + if ok, _ := client.Extension("STARTTLS"); ok { + if err = client.StartTLS(&tls.Config{ServerName: host, MinVersion: tls.VersionTLS12}); err != nil { + return fmt.Errorf("starttls: %w", err) + } + } + if err = client.Auth(auth); err != nil { return fmt.Errorf("smtp auth: %w", err) } From 9e0d12d3b03e6b1cc8cf96a0c59f6b0a288fab85 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 07:22:22 +0800 Subject: [PATCH 69/88] fix: show websearch API key visibility/copy buttons for saved providers The buttons were hidden because v-if only checked provider.api_key, which is always empty for saved providers (backend sanitizes it). Now also checks api_key_configured. Copy button is disabled when no actual key is available (only configured placeholder shown). --- backend/cmd/server/VERSION | 2 +- .../internal/handler/admin/setting_handler.go | 2 +- backend/internal/service/websearch_config.go | 22 ++++++++++++++++++ frontend/src/views/admin/SettingsView.vue | 23 ++++++++++++------- 4 files changed, 39 insertions(+), 10 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 68dda295..5657b5e3 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.51 +0.1.112.3 diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 2324cc70..2a87e95c 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1939,7 +1939,7 @@ func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) { response.ErrorFrom(c, err) return } - response.Success(c, service.SanitizeWebSearchConfig(c.Request.Context(), cfg)) + response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), cfg)) } // UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置 diff --git a/backend/internal/service/websearch_config.go b/backend/internal/service/websearch_config.go index 5658cec3..239e882a 100644 --- a/backend/internal/service/websearch_config.go +++ b/backend/internal/service/websearch_config.go @@ -277,6 +277,28 @@ func TestWebSearch(ctx context.Context, query string) (*WebSearchTestResult, err }, nil } +// PopulateWebSearchUsage returns a copy with quota usage populated from Redis (api_key kept as-is). +func PopulateWebSearchUsage(ctx context.Context, cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig { + if cfg == nil { + return nil + } + out := *cfg + out.Providers = make([]WebSearchProviderConfig, len(cfg.Providers)) + + mgr := getWebSearchManager() + + for i, p := range cfg.Providers { + out.Providers[i] = p + out.Providers[i].APIKeyConfigured = p.APIKey != "" + + if mgr != nil { + used, _ := mgr.GetUsage(ctx, p.Type) + out.Providers[i].QuotaUsed = used + } + } + return &out +} + // SanitizeWebSearchConfig returns a copy with api_key fields masked and quota usage populated. func SanitizeWebSearchConfig(ctx context.Context, cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig { if cfg == nil { diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 3ef1c0ba..12f67187 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -1775,8 +1775,8 @@ @click.stop /> - - {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit }} + + {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit > 0 ? provider.quota_limit : '∞' }} {{ t('admin.settings.webSearchEmulation.apiKeyConfigured') }} @@ -1797,10 +1797,10 @@ v-model="provider.api_key" :type="apiKeyVisible[pIdx] ? 'text' : 'password'" class="input w-full text-sm" - :class="provider.api_key ? 'pr-16' : ''" + :class="(provider.api_key || provider.api_key_configured) ? 'pr-16' : ''" :placeholder="provider.api_key_configured ? '••••••••' : t('admin.settings.webSearchEmulation.apiKeyPlaceholder')" /> -
+
-
+
{{ t('admin.settings.webSearchEmulation.quotaUsage') }}: -
+
- {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit }} +
+ {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit > 0 ? provider.quota_limit : '∞' }}
@@ -3164,9 +3167,13 @@ async function loadWebSearchConfig() { async function saveWebSearchConfig(): Promise { try { + const providers = webSearchConfig.providers.map((p: WebSearchProviderConfig) => ({ + ...p, + quota_limit: typeof p.quota_limit === 'number' && p.quota_limit > 0 ? p.quota_limit : 0, + })) await adminAPI.settings.updateWebSearchEmulationConfig({ enabled: webSearchConfig.enabled, - providers: webSearchConfig.providers as WebSearchProviderConfig[], + providers, }) return true } catch (err: unknown) { From 1e6912ea2e12beb052d1e0523aa321be619de57f Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 07:43:08 +0800 Subject: [PATCH 70/88] fix: gofmt formatting across all Go source files --- .../internal/handler/admin/setting_handler.go | 10 ++-- backend/internal/handler/dto/settings.go | 10 ++-- backend/internal/handler/dto/types.go | 8 +-- .../internal/payment/load_balancer_test.go | 2 +- .../internal/payment/provider/alipay_test.go | 6 +- backend/internal/service/account.go | 27 ++++++--- .../service/balance_notify_email_body_test.go | 18 +++--- backend/internal/service/domain_constants.go | 4 +- .../payment_config_plans_validation_test.go | 6 +- .../service/payment_config_providers.go | 59 +++++++++++++++---- .../service/payment_config_providers_test.go | 2 +- .../service/payment_config_service.go | 36 +++++------ .../service/setting_service_public_test.go | 2 +- .../service/setting_service_update_test.go | 4 +- backend/internal/service/settings_view.go | 12 ++-- backend/internal/service/usage_billing.go | 13 ++++ backend/internal/service/user.go | 2 +- backend/internal/service/user_service.go | 6 +- 18 files changed, 143 insertions(+), 84 deletions(-) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 2a87e95c..b50cad96 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -310,11 +310,11 @@ type UpdateSettingsRequest struct { EnableCCHSigning *bool `json:"enable_cch_signing"` // Balance low notification - BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` - BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` - BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"` - AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"` - AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"` + BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` + BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` + BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"` + AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"` + AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"` // Payment configuration (integrated into settings, full replace) PaymentEnabled *bool `json:"payment_enabled"` diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index d218490a..ef285a44 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -150,11 +150,11 @@ type SystemSettings struct { PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"` // Balance low notification - BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` - BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` - BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` - AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` - AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"` + BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` + BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` + BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` + AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` + AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"` } type DefaultSubscriptionSetting struct { diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index afb782b0..1aab1dbb 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -19,11 +19,11 @@ type User struct { UpdatedAt time.Time `json:"updated_at"` // 余额不足通知 - BalanceNotifyEnabled bool `json:"balance_notify_enabled"` - BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"` - BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` + BalanceNotifyEnabled bool `json:"balance_notify_enabled"` + BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"` + BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"` - TotalRecharged float64 `json:"total_recharged"` + TotalRecharged float64 `json:"total_recharged"` APIKeys []APIKey `json:"api_keys,omitempty"` Subscriptions []UserSubscription `json:"subscriptions,omitempty"` diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go index 568b56a3..04b3c25b 100644 --- a/backend/internal/payment/load_balancer_test.go +++ b/backend/internal/payment/load_balancer_test.go @@ -242,7 +242,7 @@ func TestFilterByLimits(t *testing.T) { wantIDs: nil, }, { - name: "empty candidates returns empty", + name: "empty candidates returns empty", candidates: nil, paymentType: "alipay", orderAmount: 10, diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go index 1b9d66ba..7b0ce0d8 100644 --- a/backend/internal/payment/provider/alipay_test.go +++ b/backend/internal/payment/provider/alipay_test.go @@ -98,9 +98,9 @@ func TestNewAlipay(t *testing.T) { errSubstr: "privateKey", }, { - name: "nil config map returns error for appId", - config: map[string]string{}, - wantErr: true, + name: "nil config map returns error for appId", + config: map[string]string{}, + wantErr: true, errSubstr: "appId", }, } diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index cacfb240..52db3073 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -1533,39 +1533,48 @@ func (a *Account) QuotaNotifyConfig(dim string) (enabled bool, threshold float64 } func (a *Account) GetQuotaNotifyDailyEnabled() bool { - e, _, _ := a.QuotaNotifyConfig(quotaDimDaily); return e + e, _, _ := a.QuotaNotifyConfig(quotaDimDaily) + return e } func (a *Account) GetQuotaNotifyDailyThreshold() float64 { - _, t, _ := a.QuotaNotifyConfig(quotaDimDaily); return t + _, t, _ := a.QuotaNotifyConfig(quotaDimDaily) + return t } func (a *Account) GetQuotaNotifyDailyThresholdType() string { - _, _, tt := a.QuotaNotifyConfig(quotaDimDaily); return tt + _, _, tt := a.QuotaNotifyConfig(quotaDimDaily) + return tt } func (a *Account) GetQuotaNotifyWeeklyEnabled() bool { - e, _, _ := a.QuotaNotifyConfig(quotaDimWeekly); return e + e, _, _ := a.QuotaNotifyConfig(quotaDimWeekly) + return e } func (a *Account) GetQuotaNotifyWeeklyThreshold() float64 { - _, t, _ := a.QuotaNotifyConfig(quotaDimWeekly); return t + _, t, _ := a.QuotaNotifyConfig(quotaDimWeekly) + return t } func (a *Account) GetQuotaNotifyWeeklyThresholdType() string { - _, _, tt := a.QuotaNotifyConfig(quotaDimWeekly); return tt + _, _, tt := a.QuotaNotifyConfig(quotaDimWeekly) + return tt } func (a *Account) GetQuotaNotifyTotalEnabled() bool { - e, _, _ := a.QuotaNotifyConfig(quotaDimTotal); return e + e, _, _ := a.QuotaNotifyConfig(quotaDimTotal) + return e } func (a *Account) GetQuotaNotifyTotalThreshold() float64 { - _, t, _ := a.QuotaNotifyConfig(quotaDimTotal); return t + _, t, _ := a.QuotaNotifyConfig(quotaDimTotal) + return t } func (a *Account) GetQuotaNotifyTotalThresholdType() string { - _, _, tt := a.QuotaNotifyConfig(quotaDimTotal); return tt + _, _, tt := a.QuotaNotifyConfig(quotaDimTotal) + return tt } // nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点 diff --git a/backend/internal/service/balance_notify_email_body_test.go b/backend/internal/service/balance_notify_email_body_test.go index 9baf164e..aee5a5bc 100644 --- a/backend/internal/service/balance_notify_email_body_test.go +++ b/backend/internal/service/balance_notify_email_body_test.go @@ -65,15 +65,15 @@ func TestBuildBalanceLowEmailBody_NoRechargeURLOmitsButton(t *testing.T) { func TestBuildQuotaAlertEmailBody_AllFieldsPresent(t *testing.T) { s := &BalanceNotifyService{} body := s.buildQuotaAlertEmailBody( - 42, // accountID - "acc-foo", // accountName - "anthropic", // platform - "日限额 / Daily", // dimLabel - 750.50, // used - 1000.0, // limit - 249.50, // remaining - "$249.50", // thresholdDisplay - "MySite", // siteName + 42, // accountID + "acc-foo", // accountName + "anthropic", // platform + "日限额 / Daily", // dimLabel + 750.50, // used + 1000.0, // limit + 249.50, // remaining + "$249.50", // thresholdDisplay + "MySite", // siteName ) require.Contains(t, body, "MySite") diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 896ba59f..bdced29a 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -251,8 +251,8 @@ const ( SettingKeyEnableCCHSigning = "enable_cch_signing" // Balance Low Notification - SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关 - SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD) + SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关 + SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD) SettingKeyBalanceLowNotifyRechargeURL = "balance_low_notify_recharge_url" // 充值页面 URL // Account Quota Notification diff --git a/backend/internal/service/payment_config_plans_validation_test.go b/backend/internal/service/payment_config_plans_validation_test.go index efdbdb10..bcbe901f 100644 --- a/backend/internal/service/payment_config_plans_validation_test.go +++ b/backend/internal/service/payment_config_plans_validation_test.go @@ -131,9 +131,9 @@ func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) { // --- validatePlanPatch: other fields --- -func ptrStr(s string) *string { return &s } -func ptrInt(i int) *int { return &i } -func ptrInt64(i int64) *int64 { return &i } +func ptrStr(s string) *string { return &s } +func ptrInt(i int) *int { return &i } +func ptrInt64(i int64) *int64 { return &i } func ptrFloat(f float64) *float64 { return &f } func TestValidatePlanPatch_EmptyName(t *testing.T) { diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 10181914..0c71ab29 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -22,16 +22,17 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db // ProviderInstanceResponse is the API response for a provider instance. type ProviderInstanceResponse struct { - ID int64 `json:"id"` - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Limits string `json:"limits"` - Enabled bool `json:"enabled"` - RefundEnabled bool `json:"refund_enabled"` - SortOrder int `json:"sort_order"` - PaymentMode string `json:"payment_mode"` + ID int64 `json:"id"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Limits string `json:"limits"` + Enabled bool `json:"enabled"` + RefundEnabled bool `json:"refund_enabled"` + AllowUserRefund bool `json:"allow_user_refund"` + SortOrder int `json:"sort_order"` + PaymentMode string `json:"payment_mode"` } // ListProviderInstancesWithConfig returns provider instances with decrypted config. @@ -46,8 +47,8 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte resp := ProviderInstanceResponse{ ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name, SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits, - Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, SortOrder: inst.SortOrder, - PaymentMode: inst.PaymentMode, + Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, AllowUserRefund: inst.AllowUserRefund, + SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, } resp.Config, err = s.decryptAndMaskConfig(inst.Config) if err != nil { @@ -110,10 +111,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C if err != nil { return nil, err } + allowUserRefund := req.AllowUserRefund && req.RefundEnabled return s.entClient.PaymentProviderInstance.Create(). SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc). SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode). SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled). + SetAllowUserRefund(allowUserRefund). Save(ctx) } @@ -221,6 +224,21 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } if req.RefundEnabled != nil { u.SetRefundEnabled(*req.RefundEnabled) + // Cascade: turning off refund_enabled also disables allow_user_refund + if !*req.RefundEnabled { + u.SetAllowUserRefund(false) + } + } + if req.AllowUserRefund != nil { + // Only allow enabling when refund_enabled is true + if *req.AllowUserRefund { + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) + if err == nil && inst.RefundEnabled { + u.SetAllowUserRefund(true) + } + } else { + u.SetAllowUserRefund(false) + } } if req.PaymentMode != nil { u.SetPaymentMode(*req.PaymentMode) @@ -228,6 +246,23 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in return u.Save(ctx) } +// GetUserRefundEligibleInstanceIDs returns provider instance IDs that allow user refund. +func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Context) ([]string, error) { + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.AllowUserRefundEQ(true), + paymentproviderinstance.RefundEnabledEQ(true), + ).Select(paymentproviderinstance.FieldID).All(ctx) + if err != nil { + return nil, err + } + ids := make([]string, 0, len(instances)) + for _, inst := range instances { + ids = append(ids, strconv.FormatInt(int64(inst.ID), 10)) + } + return ids, nil +} + func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newConfig map[string]string) (map[string]string, error) { inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) if err != nil { diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go index e71eb9f5..2aaa874f 100644 --- a/backend/internal/service/payment_config_providers_test.go +++ b/backend/internal/service/payment_config_providers_test.go @@ -101,7 +101,7 @@ func TestIsSensitiveConfigField(t *testing.T) { t.Parallel() tests := []struct { - field string + field string wantSen bool }{ // Sensitive fields (contain key/secret/private/password/pkey patterns) diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 9042c3ab..cce31f4d 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -105,26 +105,28 @@ type MethodLimitsResponse struct { } type CreateProviderInstanceRequest struct { - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled bool `json:"enabled"` - PaymentMode string `json:"payment_mode"` - SortOrder int `json:"sort_order"` - Limits string `json:"limits"` - RefundEnabled bool `json:"refund_enabled"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled bool `json:"enabled"` + PaymentMode string `json:"payment_mode"` + SortOrder int `json:"sort_order"` + Limits string `json:"limits"` + RefundEnabled bool `json:"refund_enabled"` + AllowUserRefund bool `json:"allow_user_refund"` } type UpdateProviderInstanceRequest struct { - Name *string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled *bool `json:"enabled"` - PaymentMode *string `json:"payment_mode"` - SortOrder *int `json:"sort_order"` - Limits *string `json:"limits"` - RefundEnabled *bool `json:"refund_enabled"` + Name *string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled *bool `json:"enabled"` + PaymentMode *string `json:"payment_mode"` + SortOrder *int `json:"sort_order"` + Limits *string `json:"limits"` + RefundEnabled *bool `json:"refund_enabled"` + AllowUserRefund *bool `json:"allow_user_refund"` } type CreatePlanRequest struct { GroupID int64 `json:"group_id"` diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go index 6dfa627c..5cf1e860 100644 --- a/backend/internal/service/setting_service_public_test.go +++ b/backend/internal/service/setting_service_public_test.go @@ -66,7 +66,7 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) { repo := &settingPublicRepoStub{ values: map[string]string{ - SettingKeyTableDefaultPageSize: "50", + SettingKeyTableDefaultPageSize: "50", SettingKeyTablePageSizeOptions: "[20,50,100]", }, } diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go index 28c7ad02..e62218b4 100644 --- a/backend/internal/service/setting_service_update_test.go +++ b/backend/internal/service/setting_service_update_test.go @@ -208,7 +208,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) { svc := NewSettingService(repo, &config.Config{}) err := svc.UpdateSettings(context.Background(), &SystemSettings{ - TableDefaultPageSize: 50, + TableDefaultPageSize: 50, TablePageSizeOptions: []int{20, 50, 100}, }) require.NoError(t, err) @@ -216,7 +216,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) { require.Equal(t, "[20,50,100]", repo.updates[SettingKeyTablePageSizeOptions]) err = svc.UpdateSettings(context.Background(), &SystemSettings{ - TableDefaultPageSize: 1000, + TableDefaultPageSize: 1000, TablePageSizeOptions: []int{20, 100}, }) require.NoError(t, err) diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 57f3746a..ec20fe0a 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -108,8 +108,8 @@ type SystemSettings struct { EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) // Balance low notification - BalanceLowNotifyEnabled bool - BalanceLowNotifyThreshold float64 + BalanceLowNotifyEnabled bool + BalanceLowNotifyThreshold float64 BalanceLowNotifyRechargeURL string // Account quota notification @@ -155,10 +155,10 @@ type PublicSettings struct { OIDCOAuthProviderName string Version string - BalanceLowNotifyEnabled bool - AccountQuotaNotifyEnabled bool - BalanceLowNotifyThreshold float64 - BalanceLowNotifyRechargeURL string + BalanceLowNotifyEnabled bool + AccountQuotaNotifyEnabled bool + BalanceLowNotifyThreshold float64 + BalanceLowNotifyRechargeURL string } // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) diff --git a/backend/internal/service/usage_billing.go b/backend/internal/service/usage_billing.go index 73b05743..30495624 100644 --- a/backend/internal/service/usage_billing.go +++ b/backend/internal/service/usage_billing.go @@ -100,9 +100,22 @@ func valueOrZero(v *int64) int64 { return *v } +// AccountQuotaState holds the post-increment quota state returned by the DB transaction. +// All values are post-update (i.e., already include the increment). +type AccountQuotaState struct { + TotalUsed float64 + TotalLimit float64 + DailyUsed float64 + DailyLimit float64 + WeeklyUsed float64 + WeeklyLimit float64 +} + type UsageBillingApplyResult struct { Applied bool APIKeyQuotaExhausted bool + NewBalance *float64 // post-deduction balance (nil = no balance deduction) + QuotaState *AccountQuotaState // post-increment quota state (nil = no quota increment) } type UsageBillingRepository interface { diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index d3d8c954..59f8aa6b 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -32,7 +32,7 @@ type User struct { // 余额不足通知 BalanceNotifyEnabled bool - BalanceNotifyThresholdType string // "fixed" (default) | "percentage" + BalanceNotifyThresholdType string // "fixed" (default) | "percentage" BalanceNotifyThreshold *float64 BalanceNotifyExtraEmails []NotifyEmailEntry TotalRecharged float64 diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index a7724a5a..3490e804 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -13,9 +13,9 @@ import ( ) var ( - ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") - ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") - ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") + ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") + ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") + ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later") ) From 7c7292935e8cefb4f2a2ebbf732bd923cb55c8e7 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 08:03:27 +0800 Subject: [PATCH 71/88] feat: websearch quota enhancements and balance notify hint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - QuotaLimit changed to *int64 (null=unlimited, >0=limited) - Add reset-usage endpoint (POST /admin/settings/web-search-emulation/reset-usage) - Show quota usage in header always (collapsed and expanded) - Add reset quota button in expanded provider view - Quota input: empty=unlimited with ∞ placeholder, must be >0 if set - Add email verification hint on balance notify card --- backend/cmd/server/VERSION | 2 +- .../internal/handler/admin/setting_handler.go | 23 +++++++++- backend/internal/pkg/websearch/manager.go | 9 ++++ backend/internal/server/http.go | 9 +++- backend/internal/server/routes/admin.go | 1 + backend/internal/service/websearch_config.go | 15 +++++-- .../internal/service/websearch_config_test.go | 18 ++++---- frontend/src/api/admin/settings.ts | 11 ++++- .../user/profile/ProfileBalanceNotifyCard.vue | 1 + frontend/src/i18n/locales/en.ts | 9 +++- frontend/src/i18n/locales/zh.ts | 9 +++- frontend/src/views/admin/SettingsView.vue | 42 +++++++++++++++---- 12 files changed, 121 insertions(+), 28 deletions(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 5657b5e3..630554d9 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.112.3 +0.1.112.4 diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index b50cad96..9b49150c 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1962,7 +1962,28 @@ func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) { response.ErrorFrom(c, err) return } - response.Success(c, service.SanitizeWebSearchConfig(c.Request.Context(), updated)) + response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), updated)) +} + +// ResetWebSearchUsage 重置指定 provider 的配额用量 +// POST /api/v1/admin/settings/web-search-emulation/reset-usage +func (h *SettingHandler) ResetWebSearchUsage(c *gin.Context) { + var req struct { + ProviderType string `json:"provider_type"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if req.ProviderType == "" { + response.BadRequest(c, "provider_type is required") + return + } + if err := service.ResetWebSearchUsage(c.Request.Context(), req.ProviderType); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, nil) } // TestWebSearchEmulation 测试 Web Search 搜索 diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go index 61faa616..307aa1e9 100644 --- a/backend/internal/pkg/websearch/manager.go +++ b/backend/internal/pkg/websearch/manager.go @@ -447,6 +447,15 @@ func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 { return result } +// ResetUsage deletes the Redis quota key for the given provider, resetting usage to 0. +func (m *Manager) ResetUsage(ctx context.Context, providerType string) error { + if m.redis == nil { + return nil + } + key := quotaRedisKey(providerType) + return m.redis.Del(ctx, key).Err() +} + // --- Provider factory --- func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider { diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index d203bab2..023e40bb 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -73,7 +73,7 @@ func ProvideRouter( pc := websearch.ProviderConfig{ Type: p.Type, APIKey: p.APIKey, - QuotaLimit: p.QuotaLimit, + QuotaLimit: derefInt64(p.QuotaLimit), ExpiresAt: p.ExpiresAt, } if p.SubscribedAt != nil { @@ -141,3 +141,10 @@ func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { // 不设置 ReadTimeout,因为大请求体可能需要较长时间读取 } } + +func derefInt64(p *int64) int64 { + if p == nil { + return 0 + } + return *p +} diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 0a7b7a8b..9af0fd8e 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -411,6 +411,7 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig) adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig) adminSettings.POST("/web-search-emulation/test", h.Admin.Setting.TestWebSearchEmulation) + adminSettings.POST("/web-search-emulation/reset-usage", h.Admin.Setting.ResetWebSearchUsage) } } diff --git a/backend/internal/service/websearch_config.go b/backend/internal/service/websearch_config.go index 239e882a..f528a35b 100644 --- a/backend/internal/service/websearch_config.go +++ b/backend/internal/service/websearch_config.go @@ -24,7 +24,7 @@ type WebSearchProviderConfig struct { Type string `json:"type"` // websearch.ProviderTypeBrave | Tavily APIKey string `json:"api_key,omitempty"` // secret — omitted in API responses APIKeyConfigured bool `json:"api_key_configured"` // read-only mask - QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited + QuotaLimit *int64 `json:"quota_limit"` // nil = unlimited, >0 = limited SubscribedAt *int64 `json:"subscribed_at,omitempty"` // subscription start (unix seconds); quota resets monthly QuotaUsed int64 `json:"quota_used,omitempty"` // read-only: current usage from Redis ProxyID *int64 `json:"proxy_id"` // optional proxy association @@ -52,8 +52,8 @@ func validateWebSearchConfig(cfg *WebSearchEmulationConfig) error { if !validProviderTypes[p.Type] { return fmt.Errorf("provider[%d]: invalid type %q", i, p.Type) } - if p.QuotaLimit < 0 { - return fmt.Errorf("provider[%d]: quota_limit must be >= 0", i) + if p.QuotaLimit != nil && *p.QuotaLimit < 0 { + return fmt.Errorf("provider[%d]: quota_limit must be > 0 or null", i) } if seen[p.Type] { return fmt.Errorf("provider[%d]: duplicate type %q", i, p.Type) @@ -299,6 +299,15 @@ func PopulateWebSearchUsage(ctx context.Context, cfg *WebSearchEmulationConfig) return &out } +// ResetWebSearchUsage deletes the Redis quota key for the given provider type. +func ResetWebSearchUsage(ctx context.Context, providerType string) error { + mgr := getWebSearchManager() + if mgr == nil { + return fmt.Errorf("web search manager not initialized") + } + return mgr.ResetUsage(ctx, providerType) +} + // SanitizeWebSearchConfig returns a copy with api_key fields masked and quota usage populated. func SanitizeWebSearchConfig(ctx context.Context, cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig { if cfg == nil { diff --git a/backend/internal/service/websearch_config_test.go b/backend/internal/service/websearch_config_test.go index 4aea98b7..8cd50d0d 100644 --- a/backend/internal/service/websearch_config_test.go +++ b/backend/internal/service/websearch_config_test.go @@ -17,8 +17,8 @@ func TestValidateWebSearchConfig_Valid(t *testing.T) { cfg := &WebSearchEmulationConfig{ Enabled: true, Providers: []WebSearchProviderConfig{ - {Type: "brave", QuotaLimit: 1000}, - {Type: "tavily", QuotaLimit: 500}, + {Type: "brave", QuotaLimit: int64Ptr(1000)}, + {Type: "tavily", QuotaLimit: int64Ptr(500)}, }, } require.NoError(t, validateWebSearchConfig(cfg)) @@ -42,9 +42,9 @@ func TestValidateWebSearchConfig_InvalidType(t *testing.T) { func TestValidateWebSearchConfig_NegativeQuotaLimit(t *testing.T) { cfg := &WebSearchEmulationConfig{ - Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: -1}}, + Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: int64Ptr(-1)}}, } - require.ErrorContains(t, validateWebSearchConfig(cfg), "quota_limit must be >= 0") + require.ErrorContains(t, validateWebSearchConfig(cfg), "quota_limit must be > 0 or null") } func TestValidateWebSearchConfig_DuplicateType(t *testing.T) { @@ -57,9 +57,9 @@ func TestValidateWebSearchConfig_DuplicateType(t *testing.T) { require.ErrorContains(t, validateWebSearchConfig(cfg), "duplicate type") } -func TestValidateWebSearchConfig_ZeroQuotaLimit(t *testing.T) { +func TestValidateWebSearchConfig_NilQuotaLimit(t *testing.T) { cfg := &WebSearchEmulationConfig{ - Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: 0}}, + Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: nil}}, } require.NoError(t, validateWebSearchConfig(cfg)) } @@ -92,7 +92,7 @@ func TestParseWebSearchConfigJSON_BackwardCompatibility(t *testing.T) { cfg := parseWebSearchConfigJSON(raw) require.True(t, cfg.Enabled) require.Len(t, cfg.Providers, 1) - require.Equal(t, int64(1000), cfg.Providers[0].QuotaLimit) + require.Equal(t, int64(1000), *cfg.Providers[0].QuotaLimit) } // --- SanitizeWebSearchConfig --- @@ -126,12 +126,12 @@ func TestSanitizeWebSearchConfig_PreservesOtherFields(t *testing.T) { cfg := &WebSearchEmulationConfig{ Enabled: true, Providers: []WebSearchProviderConfig{ - {Type: "brave", APIKey: "secret", QuotaLimit: 1000}, + {Type: "brave", APIKey: "secret", QuotaLimit: int64Ptr(1000)}, }, } out := SanitizeWebSearchConfig(context.Background(), cfg) require.True(t, out.Enabled) - require.Equal(t, int64(1000), out.Providers[0].QuotaLimit) + require.Equal(t, int64(1000), *out.Providers[0].QuotaLimit) } func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) { diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 4b5eb242..aa1d0f82 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -502,7 +502,7 @@ export interface WebSearchProviderConfig { type: 'brave' | 'tavily' api_key: string api_key_configured: boolean - quota_limit: number + quota_limit: number | null subscribed_at: number | null quota_used?: number proxy_id: number | null @@ -547,6 +547,12 @@ export async function testWebSearchEmulation( return data } +export async function resetWebSearchUsage( + payload: { provider_type: string } +): Promise { + await apiClient.post('/admin/settings/web-search-emulation/reset-usage', payload) +} + export const settingsAPI = { getSettings, updateSettings, @@ -565,7 +571,8 @@ export const settingsAPI = { updateBetaPolicySettings, getWebSearchEmulationConfig, updateWebSearchEmulationConfig, - testWebSearchEmulation + testWebSearchEmulation, + resetWebSearchUsage } export default settingsAPI diff --git a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue index 3a84fd6b..c4d04153 100644 --- a/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue +++ b/frontend/src/components/user/profile/ProfileBalanceNotifyCard.vue @@ -48,6 +48,7 @@
+

{{ t('profile.balanceNotify.extraEmailsHint') }}

diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index c8acf6c0..9baddc43 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -914,6 +914,7 @@ export default { thresholdPlaceholder: 'Enter amount', systemDefault: 'System Default', extraEmails: 'Notification Emails', + extraEmailsHint: 'You must add and verify an email address to receive low balance alerts', primaryEmail: 'Primary', noExtraEmails: 'No extra notification emails', enterEmail: 'Enter email address', @@ -4435,10 +4436,14 @@ export default { copyApiKey: 'Copy', copied: 'Copied', quotaLimit: 'Quota Limit', - quotaLimitHint: '0 = unlimited', + quotaLimitHint: 'Leave empty for unlimited; must be > 0 if set', + quotaLimitMustBePositive: 'Quota limit must be greater than 0', subscribedAt: 'Subscribed At', - subscribedAtHint: 'Quota resets monthly from this date', + subscribedAtHint: 'Quota resets monthly from this date; leave empty to disable auto-reset', quotaUsage: 'Usage', + resetUsage: 'Reset', + resetUsageConfirm: 'Reset usage counter for this provider?', + resetUsageSuccess: 'Usage counter reset', proxy: 'Proxy', removeProvider: 'Remove', noProviders: 'No search providers configured', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 499ed9cb..af8da265 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -918,6 +918,7 @@ export default { thresholdPlaceholder: '输入金额', systemDefault: '系统默认值', extraEmails: '通知邮箱', + extraEmailsHint: '必须添加并验证邮箱后,余额不足时才能收到提醒邮件', primaryEmail: '主邮箱', noExtraEmails: '暂无额外通知邮箱', enterEmail: '输入邮箱地址', @@ -4597,10 +4598,14 @@ export default { copyApiKey: '复制', copied: '已复制', quotaLimit: '配额上限', - quotaLimitHint: '0 表示无限制', + quotaLimitHint: '留空表示无限制;填写时必须大于 0', + quotaLimitMustBePositive: '配额上限必须大于 0', subscribedAt: '订阅时间', - subscribedAtHint: '配额从此日期起每月自动重置', + subscribedAtHint: '配额从此日期起每月自动重置;留空则不自动重置', quotaUsage: '用量', + resetUsage: '重置', + resetUsageConfirm: '确定要重置此服务商的用量计数吗?', + resetUsageSuccess: '用量已重置', proxy: '代理', removeProvider: '删除', noProviders: '未配置搜索服务商', diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 12f67187..c57d2033 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -1774,9 +1774,9 @@ class="w-36" @click.stop /> - - - {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit > 0 ? provider.quota_limit : '∞' }} + + + {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit != null && provider.quota_limit > 0 ? provider.quota_limit : '∞' }} {{ t('admin.settings.webSearchEmulation.apiKeyConfigured') }} @@ -1835,7 +1835,7 @@
- +

{{ t('admin.settings.webSearchEmulation.quotaLimitHint') }}

@@ -1853,7 +1853,7 @@
{{ t('admin.settings.webSearchEmulation.quotaUsage') }}: -
+
- {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit > 0 ? provider.quota_limit : '∞' }} + {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit != null && provider.quota_limit > 0 ? provider.quota_limit : '∞' }} +
@@ -3118,6 +3126,19 @@ function quotaPercentage(provider: WebSearchProviderConfig): number { return ((provider.quota_used ?? 0) / provider.quota_limit) * 100 } +async function resetWebSearchUsage(idx: number) { + const provider = webSearchConfig.providers[idx] + if (!provider) return + if (!confirm(t('admin.settings.webSearchEmulation.resetUsageConfirm'))) return + try { + await adminAPI.settings.resetWebSearchUsage({ provider_type: provider.type }) + provider.quota_used = 0 + appStore.showSuccess(t('admin.settings.webSearchEmulation.resetUsageSuccess')) + } catch (err: unknown) { + appStore.showError(extractApiErrorMessage(err, t('common.error'))) + } +} + async function copyApiKey(idx: number) { const key = webSearchConfig.providers[idx]?.api_key if (!key) { @@ -3167,9 +3188,16 @@ async function loadWebSearchConfig() { async function saveWebSearchConfig(): Promise { try { + for (const p of webSearchConfig.providers) { + const raw = p.quota_limit + if (raw != null && Number(raw) !== 0 && Number(raw) < 1) { + appStore.showError(t('admin.settings.webSearchEmulation.quotaLimitMustBePositive')) + return false + } + } const providers = webSearchConfig.providers.map((p: WebSearchProviderConfig) => ({ ...p, - quota_limit: typeof p.quota_limit === 'number' && p.quota_limit > 0 ? p.quota_limit : 0, + quota_limit: Number(p.quota_limit) > 0 ? Number(p.quota_limit) : null, })) await adminAPI.settings.updateWebSearchEmulationConfig({ enabled: webSearchConfig.enabled, From 9028d2085f12f5ea30fc94165af21b767e455643 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 08:42:28 +0800 Subject: [PATCH 72/88] test: add unit tests for billing, websearch, and notify systems Billing (25 tests): - CalculateCostUnified: nil resolver fallback, token/per_request/image modes - GetModelPricingWithChannel: nil/partial/full channel overrides - resolveAccountStatsCost: four-level priority chain integration tests WebSearch (18 tests): - PopulateWebSearchUsage: nil input, manager states, QuotaLimit nil/*int64 - ResetWebSearchUsage: nil manager error - Manager.ResetUsage: nil Redis - shouldEmulateWebSearch: full decision chain (8 scenarios) Notify (36 tests): - ParseNotifyEmails/MarshalNotifyEmails: old/new format, roundtrip - crossedDownward: boundary values, threshold semantics - checkQuotaDimCrossings: mixed dimensions, disabled/zero skip --- .../internal/pkg/websearch/manager_test.go | 8 + .../service/account_stats_pricing_test.go | 242 ++++++++++++++++ .../service/balance_notify_check_test.go | 224 +++++++++++++++ .../internal/service/billing_service_test.go | 120 ++++++++ .../service/billing_service_unified_test.go | 258 ++++++++++++++++++ .../gateway_websearch_emulation_test.go | 238 ++++++++++++++++ .../service/notify_email_entry_test.go | 156 +++++++++++ .../internal/service/websearch_config_test.go | 123 +++++++++ 8 files changed, 1369 insertions(+) create mode 100644 backend/internal/service/billing_service_unified_test.go create mode 100644 backend/internal/service/notify_email_entry_test.go diff --git a/backend/internal/pkg/websearch/manager_test.go b/backend/internal/pkg/websearch/manager_test.go index a4beef68..cbcf1b76 100644 --- a/backend/internal/pkg/websearch/manager_test.go +++ b/backend/internal/pkg/websearch/manager_test.go @@ -313,3 +313,11 @@ func TestNewHTTPClient_ValidSOCKS5Proxy(t *testing.T) { require.NoError(t, err) require.NotNil(t, c) } + +// --- ResetUsage --- + +func TestManager_ResetUsage_NilRedis(t *testing.T) { + m := NewManager(nil, nil) + err := m.ResetUsage(context.Background(), "brave") + require.NoError(t, err) +} diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go index 23409d5e..36e5eb74 100644 --- a/backend/internal/service/account_stats_pricing_test.go +++ b/backend/internal/service/account_stats_pricing_test.go @@ -3,7 +3,9 @@ package service import ( + "context" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -527,3 +529,243 @@ func TestTryModelFilePricing_WithCacheTokens(t *testing.T) { // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95 require.InDelta(t, 0.95, *result, 1e-12) } + +// --------------------------------------------------------------------------- +// resolveAccountStatsCost — integration tests covering the 4-level priority chain +// --------------------------------------------------------------------------- + +func TestResolveAccountStatsCost_NilChannelService(t *testing.T) { + result := resolveAccountStatsCost( + context.Background(), + nil, // channelService is nil + newTestBillingServiceWithPrices(map[string]*ModelPricing{}), + 1, 1, "claude-sonnet-4", + UsageTokens{InputTokens: 100}, 1, 0.5, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_EmptyUpstreamModel(t *testing.T) { + cs := newTestChannelServiceForStats(t, &Channel{ + ID: 1, + Status: StatusActive, + }, 1, "") + + result := resolveAccountStatsCost( + context.Background(), + cs, + newTestBillingServiceWithPrices(map[string]*ModelPricing{}), + 1, 1, "", // empty upstream model + UsageTokens{InputTokens: 100}, 1, 0.5, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_GetChannelForGroupReturnsNil(t *testing.T) { + // Group 99 is NOT in the cache, so GetChannelForGroup returns nil + cs := newTestChannelServiceForStats(t, &Channel{ + ID: 1, + Status: StatusActive, + }, 1, "") + + result := resolveAccountStatsCost( + context.Background(), + cs, + newTestBillingServiceWithPrices(map[string]*ModelPricing{}), + 1, 99, "claude-sonnet-4", // groupID 99 has no channel + UsageTokens{InputTokens: 100}, 1, 0.5, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_HitsCustomRule(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + GroupIDs: []int64{10}, + Pricing: []ChannelModelPricing{ + { + ID: 100, + Models: []string{"claude-sonnet-4"}, + InputPrice: testPtrFloat64(0.01), + OutputPrice: testPtrFloat64(0.02), + }, + }, + }, + }, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, // billingService not needed when custom rule hits + 1, 10, "claude-sonnet-4", + tokens, 1, 999.0, // totalCost ignored because custom rule hits + ) + require.NotNil(t, result) + // 100*0.01 + 50*0.02 = 1.0 + 1.0 = 2.0 + require.InDelta(t, 2.0, *result, 1e-12) +} + +func TestResolveAccountStatsCost_ApplyPricingToAccountStats_UsesTotalCost(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: true, + // No custom rules + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, + 1, 10, "claude-sonnet-4", + tokens, 1, 0.75, // totalCost = 0.75 + ) + require.NotNil(t, result) + require.InDelta(t, 0.75, *result, 1e-12) +} + +func TestResolveAccountStatsCost_ApplyPricingToAccountStats_ZeroTotalCost_ReturnsNil(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: true, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, + 1, 10, "claude-sonnet-4", + UsageTokens{}, 1, 0.0, // totalCost = 0 + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_FallsBackToLiteLLM(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: false, // not enabled + // No custom rules + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + }, + }) + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, bs, + 1, 10, "claude-sonnet-4", + tokens, 1, 999.0, // totalCost ignored + ) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2 + require.InDelta(t, 0.2, *result, 1e-12) +} + +func TestResolveAccountStatsCost_AllMiss_ReturnsNil(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: false, + // No custom rules + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + // BillingService with no pricing for the model + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{}) + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, bs, + 1, 10, "totally-unknown-model", + tokens, 1, 0.0, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_NilBillingService_SkipsLiteLLM(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: false, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, // billingService is nil + 1, 10, "claude-sonnet-4", + UsageTokens{InputTokens: 100}, 1, 0.0, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_CustomRulePriorityOverApplyPricing(t *testing.T) { + // Both custom rule and ApplyPricingToAccountStats are configured; + // custom rule should take precedence. + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: true, + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + GroupIDs: []int64{10}, + Pricing: []ChannelModelPricing{ + { + ID: 100, + Models: []string{"claude-sonnet-4"}, + InputPrice: testPtrFloat64(0.05), + }, + }, + }, + }, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + tokens := UsageTokens{InputTokens: 100} + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, + 1, 10, "claude-sonnet-4", + tokens, 1, 99.0, // totalCost = 99.0 (would be used if ApplyPricing wins) + ) + require.NotNil(t, result) + // Custom rule: 100*0.05 = 5.0 (NOT 99.0 from totalCost) + require.InDelta(t, 5.0, *result, 1e-12) +} + +// --------------------------------------------------------------------------- +// helpers for resolveAccountStatsCost tests +// --------------------------------------------------------------------------- + +// newTestChannelServiceForStats creates a ChannelService with a single channel +// mapped to the given groupID, suitable for resolveAccountStatsCost tests. +func newTestChannelServiceForStats(t *testing.T, channel *Channel, groupID int64, platform string) *ChannelService { + t.Helper() + cache := newEmptyChannelCache() + cache.channelByGroupID[groupID] = channel + cache.groupPlatform[groupID] = platform + cs := &ChannelService{} + cache.loadedAt = time.Now() + cs.cache.Store(cache) + return cs +} diff --git a/backend/internal/service/balance_notify_check_test.go b/backend/internal/service/balance_notify_check_test.go index 955f3129..7bb4cf9e 100644 --- a/backend/internal/service/balance_notify_check_test.go +++ b/backend/internal/service/balance_notify_check_test.go @@ -178,3 +178,227 @@ func TestGetSiteName_Configured(t *testing.T) { repo.data[SettingKeySiteName] = "My Site" require.Equal(t, "My Site", s.getSiteName(context.Background())) } + +// ---------- crossedDownward ---------- + +func TestCrossedDownward_CrossesBelow(t *testing.T) { + // oldBalance > threshold, newBalance < threshold → true + require.True(t, crossedDownward(100, 5, 10)) +} + +func TestCrossedDownward_ExactlyAtThreshold(t *testing.T) { + // oldBalance > threshold, newBalance == threshold → false (not below) + require.False(t, crossedDownward(100, 10, 10)) +} + +func TestCrossedDownward_OldExactlyAtThreshold_NewBelow(t *testing.T) { + // oldBalance == threshold, newBalance < threshold → true + // (at-or-above → below counts as a crossing) + require.True(t, crossedDownward(10, 5, 10)) +} + +func TestCrossedDownward_AlreadyBelow(t *testing.T) { + // oldBalance < threshold → false (already below, no new crossing) + require.False(t, crossedDownward(5, 3, 10)) +} + +func TestCrossedDownward_BothAbove(t *testing.T) { + // oldBalance > threshold, newBalance > threshold → false (no crossing) + require.False(t, crossedDownward(100, 50, 10)) +} + +func TestCrossedDownward_ZeroThreshold(t *testing.T) { + // threshold == 0 → oldV >= 0 is always true, but newV < 0 only for negatives + // Typical case: positive balances should not fire when threshold is 0. + require.False(t, crossedDownward(10, 5, 0)) + require.False(t, crossedDownward(0, 0, 0)) +} + +func TestCrossedDownward_ZeroThreshold_NegativeNew(t *testing.T) { + // Edge case: newBalance goes negative with threshold=0. + require.True(t, crossedDownward(5, -1, 0)) +} + +func TestCrossedDownward_NegativeValues(t *testing.T) { + // Both already negative, threshold is positive → no crossing (already below). + require.False(t, crossedDownward(-5, -10, 10)) +} + +func TestCrossedDownward_LargeDecrement(t *testing.T) { + // A single large deduction crosses the threshold. + require.True(t, crossedDownward(1000, 0.5, 100)) +} + +func TestCrossedDownward_SmallDecrement_NoCrossing(t *testing.T) { + // A tiny deduction stays above threshold. + require.False(t, crossedDownward(100, 99.99, 10)) +} + +// ---------- checkQuotaDimCrossings ---------- + +func TestCheckQuotaDimCrossings_NoDimensions(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // Empty dims → no crossing, no panic. + s.checkQuotaDimCrossings(account, nil, 10, []string{"admin@example.com"}, "TestSite") + s.checkQuotaDimCrossings(account, []quotaDim{}, 10, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_DisabledDimension(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: false, // disabled + threshold: 100, + thresholdType: thresholdTypeFixed, + currentUsed: 950, + limit: 1000, + }, + } + // Disabled dimension should be skipped even if crossing would occur. + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_ZeroThresholdSkipped(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 0, // zero threshold + thresholdType: thresholdTypeFixed, + currentUsed: 950, + limit: 1000, + }, + } + // Zero threshold → skipped. + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_NoCrossing_BothBelowThreshold(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger) + // currentUsed=300 (after), oldUsed=300-50=250 (before). Both < 600, no crossing. + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 400, + thresholdType: thresholdTypeFixed, + currentUsed: 300, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_NoCrossing_BothAboveThreshold(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger) + // currentUsed=800 (after), oldUsed=800-50=750 (before). Both >= 600, no crossing. + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 400, + thresholdType: thresholdTypeFixed, + currentUsed: 800, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_NegativeResolvedThreshold_Skipped(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=1200 remaining, limit=1000 → effectiveThreshold = 1000-1200 = -200 + // Negative resolved threshold → skipped. + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 1200, + thresholdType: thresholdTypeFixed, + currentUsed: 950, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_PercentageThreshold_NoCrossing(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=30%, limit=1000 → effectiveThreshold = 1000 * (1 - 0.30) = 700 + // currentUsed=500, oldUsed=500-50=450. Both < 700, no crossing. + dims := []quotaDim{ + { + name: quotaDimWeekly, + enabled: true, + threshold: 30, + thresholdType: thresholdTypePercentage, + currentUsed: 500, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_ZeroLimit_Skipped(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // limit=0 → resolvedThreshold returns 0 → skipped. + dims := []quotaDim{ + { + name: quotaDimTotal, + enabled: true, + threshold: 100, + thresholdType: thresholdTypeFixed, + currentUsed: 50, + limit: 0, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_MultipleDims_MixedResults(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // dim1: no crossing (both below effective threshold) + // dim2: disabled (skipped) + // dim3: zero threshold (skipped) + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 400, + thresholdType: thresholdTypeFixed, + currentUsed: 300, // oldUsed=250, effectiveThreshold=600, both below + limit: 1000, + }, + { + name: quotaDimWeekly, + enabled: false, + threshold: 100, + thresholdType: thresholdTypeFixed, + currentUsed: 900, + limit: 1000, + }, + { + name: quotaDimTotal, + enabled: true, + threshold: 0, + thresholdType: thresholdTypeFixed, + currentUsed: 500, + limit: 1000, + }, + } + // None should trigger. No panic expected. + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 6f6c41ce..2cf134e2 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -718,3 +718,123 @@ func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing. require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12) require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12) } + +// --------------------------------------------------------------------------- +// GetModelPricingWithChannel +// --------------------------------------------------------------------------- + +func TestGetModelPricingWithChannel_NilChannelPricing_ReturnsOriginal(t *testing.T) { + svc := newTestBillingService() + + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", nil) + require.NoError(t, err) + require.NotNil(t, pricing) + + // Should be identical to GetModelPricing + original, err := svc.GetModelPricing("claude-sonnet-4") + require.NoError(t, err) + require.InDelta(t, original.InputPricePerToken, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, original.OutputPricePerToken, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, original.CacheCreationPricePerToken, pricing.CacheCreationPricePerToken, 1e-12) + require.InDelta(t, original.CacheReadPricePerToken, pricing.CacheReadPricePerToken, 1e-12) +} + +func TestGetModelPricingWithChannel_OverrideInputPriceOnly(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + InputPrice: testPtrFloat64(99e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + // InputPrice overridden (both normal and priority) + require.InDelta(t, 99e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 99e-6, pricing.InputPricePerTokenPriority, 1e-12) + + // OutputPrice unchanged (claude-sonnet-4 fallback = 15e-6) + require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12) +} + +func TestGetModelPricingWithChannel_OverrideOutputPriceOnly(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + OutputPrice: testPtrFloat64(88e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + // OutputPrice overridden + require.InDelta(t, 88e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 88e-6, pricing.OutputPricePerTokenPriority, 1e-12) + + // InputPrice unchanged (claude-sonnet-4 fallback = 3e-6) + require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) +} + +func TestGetModelPricingWithChannel_OverrideAllFields(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + InputPrice: testPtrFloat64(10e-6), + OutputPrice: testPtrFloat64(20e-6), + CacheWritePrice: testPtrFloat64(5e-6), + CacheReadPrice: testPtrFloat64(1e-6), + ImageOutputPrice: testPtrFloat64(50e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + require.InDelta(t, 10e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 10e-6, pricing.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 20e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 20e-6, pricing.OutputPricePerTokenPriority, 1e-12) + require.InDelta(t, 5e-6, pricing.CacheCreationPricePerToken, 1e-12) + require.InDelta(t, 5e-6, pricing.CacheCreation5mPrice, 1e-12) + require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12) + require.InDelta(t, 1e-6, pricing.CacheReadPricePerToken, 1e-12) + require.InDelta(t, 1e-6, pricing.CacheReadPricePerTokenPriority, 1e-12) + require.InDelta(t, 50e-6, pricing.ImageOutputPricePerToken, 1e-12) +} + +func TestGetModelPricingWithChannel_CacheWritePriceAffects5mAnd1h(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + CacheWritePrice: testPtrFloat64(7e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + // CacheWritePrice should set all three: CacheCreationPricePerToken, 5m, and 1h + require.InDelta(t, 7e-6, pricing.CacheCreationPricePerToken, 1e-12) + require.InDelta(t, 7e-6, pricing.CacheCreation5mPrice, 1e-12) + require.InDelta(t, 7e-6, pricing.CacheCreation1hPrice, 1e-12) +} + +func TestGetModelPricingWithChannel_CacheReadPriceAffectsPriority(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + CacheReadPrice: testPtrFloat64(2e-6), + } + pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing) + require.NoError(t, err) + + // CacheReadPrice should set both normal and priority + require.InDelta(t, 2e-6, pricing.CacheReadPricePerToken, 1e-12) + require.InDelta(t, 2e-6, pricing.CacheReadPricePerTokenPriority, 1e-12) +} + +func TestGetModelPricingWithChannel_UnknownModelReturnsError(t *testing.T) { + svc := newTestBillingService() + + chPricing := &ChannelModelPricing{ + InputPrice: testPtrFloat64(1e-6), + } + pricing, err := svc.GetModelPricingWithChannel("totally-unknown-model", chPricing) + require.Error(t, err) + require.Nil(t, pricing) + require.Contains(t, err.Error(), "pricing not found") +} diff --git a/backend/internal/service/billing_service_unified_test.go b/backend/internal/service/billing_service_unified_test.go new file mode 100644 index 00000000..694c3384 --- /dev/null +++ b/backend/internal/service/billing_service_unified_test.go @@ -0,0 +1,258 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// CalculateCostUnified +// --------------------------------------------------------------------------- + +func TestCalculateCostUnified_NilResolver_FallsBackToOldPath(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + input := CostInput{ + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 1.0, + Resolver: nil, // no resolver + } + cost, err := svc.CalculateCostUnified(input) + require.NoError(t, err) + + // Should match the old-path result exactly + expected, err := svc.calculateCostInternal("claude-sonnet-4", tokens, 1.0, "", nil) + require.NoError(t, err) + require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-10) + require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10) + // BillingMode is NOT set by old path through CalculateCostUnified (resolver == nil) + require.Empty(t, cost.BillingMode) +} + +func TestCalculateCostUnified_TokenMode(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + input := CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 1.5, + Resolver: resolver, + } + cost, err := bs.CalculateCostUnified(input) + require.NoError(t, err) + require.NotNil(t, cost) + + // Verify token billing: Input: 1000*3e-6=0.003, Output: 500*15e-6=0.0075 + expectedTotal := 1000*3e-6 + 500*15e-6 + require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10) + require.InDelta(t, expectedTotal*1.5, cost.ActualCost, 1e-10) + require.Equal(t, string(BillingModeToken), cost.BillingMode) +} + +func TestCalculateCostUnified_PerRequestMode(t *testing.T) { + // Set up a ChannelService with a per-request pricing channel + cs := newTestChannelServiceWithCache(t, &channelCache{ + pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{ + {groupID: 1, model: "claude-sonnet-4"}: { + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0.05), + }, + }, + channelByGroupID: map[int64]*Channel{ + 1: {ID: 1, Status: StatusActive}, + }, + groupPlatform: map[int64]string{1: ""}, + wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{}, + mappingByGroupModel: map[channelModelKey]string{}, + wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{}, + byID: map[int64]*Channel{}, + }) + + bs := newTestBillingService() + resolver := NewModelPricingResolver(cs, bs) + groupID := int64(1) + + input := CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + GroupID: &groupID, + Tokens: UsageTokens{InputTokens: 100, OutputTokens: 50}, + RequestCount: 3, + RateMultiplier: 2.0, + Resolver: resolver, + } + cost, err := bs.CalculateCostUnified(input) + require.NoError(t, err) + require.NotNil(t, cost) + + // 3 requests * $0.05 = $0.15 + require.InDelta(t, 0.15, cost.TotalCost, 1e-10) + // ActualCost = 0.15 * 2.0 = 0.30 + require.InDelta(t, 0.30, cost.ActualCost, 1e-10) + require.Equal(t, string(BillingModePerRequest), cost.BillingMode) +} + +func TestCalculateCostUnified_ImageMode(t *testing.T) { + cs := newTestChannelServiceWithCache(t, &channelCache{ + pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{ + {groupID: 2, model: "gemini-image"}: { + BillingMode: BillingModeImage, + PerRequestPrice: testPtrFloat64(0.10), + }, + }, + channelByGroupID: map[int64]*Channel{ + 2: {ID: 2, Status: StatusActive}, + }, + groupPlatform: map[int64]string{2: ""}, + wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{}, + mappingByGroupModel: map[channelModelKey]string{}, + wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{}, + byID: map[int64]*Channel{}, + }) + + bs := &BillingService{ + cfg: &config.Config{}, + fallbackPrices: map[string]*ModelPricing{}, + } + resolver := NewModelPricingResolver(cs, bs) + groupID := int64(2) + + input := CostInput{ + Ctx: context.Background(), + Model: "gemini-image", + GroupID: &groupID, + Tokens: UsageTokens{}, + RequestCount: 2, + RateMultiplier: 1.0, + Resolver: resolver, + } + cost, err := bs.CalculateCostUnified(input) + require.NoError(t, err) + require.NotNil(t, cost) + + // 2 * $0.10 = $0.20 + require.InDelta(t, 0.20, cost.TotalCost, 1e-10) + require.InDelta(t, 0.20, cost.ActualCost, 1e-10) + require.Equal(t, string(BillingModeImage), cost.BillingMode) +} + +func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + + costZero, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 0, // should default to 1.0 + Resolver: resolver, + }) + require.NoError(t, err) + + costOne, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 1.0, + Resolver: resolver, + }) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) +} + +func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + tokens := UsageTokens{InputTokens: 1000} + + costNeg, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: -5.0, + Resolver: resolver, + }) + require.NoError(t, err) + + costOne, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: tokens, + RateMultiplier: 1.0, + Resolver: resolver, + }) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) +} + +func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + cost, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: UsageTokens{InputTokens: 100}, + RateMultiplier: 1.0, + Resolver: resolver, + }) + require.NoError(t, err) + require.Equal(t, "token", cost.BillingMode) +} + +func TestCalculateCostUnified_UsesPreResolvedPricing(t *testing.T) { + bs := newTestBillingService() + resolver := NewModelPricingResolver(nil, bs) + + // Pre-resolve with per_request mode to verify it's used instead of re-resolving + preResolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + DefaultPerRequestPrice: 0.07, + } + + cost, err := bs.CalculateCostUnified(CostInput{ + Ctx: context.Background(), + Model: "claude-sonnet-4", + Tokens: UsageTokens{InputTokens: 100}, + RequestCount: 2, + RateMultiplier: 1.0, + Resolver: resolver, + Resolved: preResolved, + }) + require.NoError(t, err) + require.NotNil(t, cost) + + // 2 * $0.07 = $0.14 + require.InDelta(t, 0.14, cost.TotalCost, 1e-10) + require.Equal(t, string(BillingModePerRequest), cost.BillingMode) +} + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +// newTestChannelServiceWithCache creates a ChannelService with a pre-populated +// cache snapshot, bypassing the repository layer entirely. +func newTestChannelServiceWithCache(t *testing.T, cache *channelCache) *ChannelService { + t.Helper() + cs := &ChannelService{} + cache.loadedAt = time.Now() + cs.cache.Store(cache) + return cs +} diff --git a/backend/internal/service/gateway_websearch_emulation_test.go b/backend/internal/service/gateway_websearch_emulation_test.go index b606c748..de1f0014 100644 --- a/backend/internal/service/gateway_websearch_emulation_test.go +++ b/backend/internal/service/gateway_websearch_emulation_test.go @@ -1,8 +1,14 @@ +//go:build unit + package service import ( + "context" + "encoding/json" "testing" + "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/websearch" "github.com/stretchr/testify/require" ) @@ -140,3 +146,235 @@ func TestBuildTextSummary_NoResults(t *testing.T) { summary := buildTextSummary("test", nil) require.Contains(t, summary, "No search results found for: test") } + +// --- shouldEmulateWebSearch --- + +// webSearchToolBody is a valid request body with exactly one web_search tool. +var webSearchToolBody = []byte(`{"tools":[{"type":"web_search"}],"messages":[{"role":"user","content":"test"}]}`) + +// nonWebSearchToolBody is a request body without web_search tool. +var nonWebSearchToolBody = []byte(`{"tools":[{"type":"text_editor"}],"messages":[{"role":"user","content":"test"}]}`) + +// newAnthropicAPIKeyAccount creates a test Account with the given web search emulation mode. +func newAnthropicAPIKeyAccount(mode string) *Account { + return &Account{ + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: mode}, + } +} + +// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled. +func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) { + webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{ + config: cfg, + expiresAt: time.Now().Add(10 * time.Minute).UnixNano(), + }) +} + +// clearGlobalWebSearchConfig resets the global cache to force re-read. +func clearGlobalWebSearchConfig() { + webSearchEmulationCache.Store((*cachedWebSearchEmulationConfig)(nil)) +} + +// newSettingServiceForWebSearchTest creates a SettingService with a mock repo pre-loaded with config. +func newSettingServiceForWebSearchTest(enabled bool) *SettingService { + repo := newMockSettingRepo() + cfg := &WebSearchEmulationConfig{ + Enabled: enabled, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "sk-test"}}, + } + data, _ := json.Marshal(cfg) + repo.data[SettingKeyWebSearchEmulationConfig] = string(data) + return NewSettingService(repo, &config.Config{}) +} + +// newChannelServiceWithCache creates a ChannelService with a pre-built cache containing the channel. +func newChannelServiceWithCache(groupID int64, ch *Channel) *ChannelService { + svc := &ChannelService{} + cache := &channelCache{ + channelByGroupID: map[int64]*Channel{groupID: ch}, + byID: map[int64]*Channel{ch.ID: ch}, + groupPlatform: map[int64]string{}, + loadedAt: time.Now(), + } + svc.cache.Store(cache) + return svc +} + +func TestShouldEmulateWebSearch_NilManager(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + settingSvc := newSettingServiceForWebSearchTest(true) + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeEnabled) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_NotOnlyWebSearchTool(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + settingSvc := newSettingServiceForWebSearchTest(true) + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeEnabled) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, nonWebSearchToolBody)) +} + +func TestShouldEmulateWebSearch_GlobalDisabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + // Global config disabled + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: false, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(false) + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeEnabled) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_AccountDisabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeDisabled) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_AccountEnabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeEnabled) + require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_DefaultMode_ChannelEnabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + ch := &Channel{ + ID: 10, + Status: StatusActive, + FeaturesConfig: map[string]any{ + featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: true}, + }, + } + channelSvc := newChannelServiceWithCache(42, ch) + svc := &GatewayService{settingService: settingSvc, channelService: channelSvc} + + account := newAnthropicAPIKeyAccount(WebSearchModeDefault) + groupID := int64(42) + require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_DefaultMode_ChannelDisabled(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + ch := &Channel{ + ID: 10, + Status: StatusActive, + FeaturesConfig: map[string]any{ + featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: false}, + }, + } + channelSvc := newChannelServiceWithCache(42, ch) + svc := &GatewayService{settingService: settingSvc, channelService: channelSvc} + + account := newAnthropicAPIKeyAccount(WebSearchModeDefault) + groupID := int64(42) + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_DefaultMode_NilGroupID(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + svc := &GatewayService{settingService: settingSvc} + account := newAnthropicAPIKeyAccount(WebSearchModeDefault) + // nil groupID + default mode → falls through to channel check → returns false + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody)) +} + +func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) { + mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + setGlobalWebSearchConfig(&WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}}, + }) + defer clearGlobalWebSearchConfig() + + settingSvc := newSettingServiceForWebSearchTest(true) + svc := &GatewayService{settingService: settingSvc, channelService: nil} + account := newAnthropicAPIKeyAccount(WebSearchModeDefault) + groupID := int64(42) + // nil channelService + default mode → returns false + require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody)) +} diff --git a/backend/internal/service/notify_email_entry_test.go b/backend/internal/service/notify_email_entry_test.go new file mode 100644 index 00000000..0f4bb12e --- /dev/null +++ b/backend/internal/service/notify_email_entry_test.go @@ -0,0 +1,156 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// ---------- ParseNotifyEmails ---------- + +func TestParseNotifyEmails_EmptyString(t *testing.T) { + result := ParseNotifyEmails("") + require.Nil(t, result) +} + +func TestParseNotifyEmails_EmptyArray(t *testing.T) { + result := ParseNotifyEmails("[]") + require.Nil(t, result) +} + +func TestParseNotifyEmails_Null(t *testing.T) { + // "null" is valid JSON that unmarshals into a nil string slice. + // The old-format branch then returns an empty (non-nil) slice. + result := ParseNotifyEmails("null") + require.Empty(t, result) +} + +func TestParseNotifyEmails_WhitespaceOnly(t *testing.T) { + result := ParseNotifyEmails(" ") + require.Nil(t, result) +} + +func TestParseNotifyEmails_OldFormat(t *testing.T) { + raw := `["alice@example.com", "bob@example.com"]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 2) + + require.Equal(t, "alice@example.com", result[0].Email) + require.False(t, result[0].Verified, "old format emails should default to unverified") + require.False(t, result[0].Disabled) + + require.Equal(t, "bob@example.com", result[1].Email) + require.False(t, result[1].Verified) + require.False(t, result[1].Disabled) +} + +func TestParseNotifyEmails_OldFormat_SkipsEmptyEntries(t *testing.T) { + raw := `["alice@example.com", "", " ", "bob@example.com"]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 2) + require.Equal(t, "alice@example.com", result[0].Email) + require.Equal(t, "bob@example.com", result[1].Email) +} + +func TestParseNotifyEmails_NewFormat(t *testing.T) { + raw := `[{"email":"alice@example.com","verified":true,"disabled":false},{"email":"bob@example.com","verified":false,"disabled":true}]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 2) + + require.Equal(t, "alice@example.com", result[0].Email) + require.True(t, result[0].Verified) + require.False(t, result[0].Disabled) + + require.Equal(t, "bob@example.com", result[1].Email) + require.False(t, result[1].Verified) + require.True(t, result[1].Disabled) +} + +func TestParseNotifyEmails_NewFormat_SingleEntry(t *testing.T) { + raw := `[{"email":"solo@example.com","verified":true,"disabled":false}]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 1) + require.Equal(t, "solo@example.com", result[0].Email) + require.True(t, result[0].Verified) +} + +func TestParseNotifyEmails_InvalidJSON(t *testing.T) { + result := ParseNotifyEmails(`{not valid json`) + require.Nil(t, result) +} + +func TestParseNotifyEmails_InvalidJSONObject(t *testing.T) { + // A plain JSON object (not array) should return nil. + result := ParseNotifyEmails(`{"email":"a@b.com"}`) + require.Nil(t, result) +} + +func TestParseNotifyEmails_WhitespacePadding(t *testing.T) { + raw := ` ["padded@example.com"] ` + result := ParseNotifyEmails(raw) + require.Len(t, result, 1) + require.Equal(t, "padded@example.com", result[0].Email) +} + +// ---------- MarshalNotifyEmails ---------- + +func TestMarshalNotifyEmails_EmptySlice(t *testing.T) { + result := MarshalNotifyEmails([]NotifyEmailEntry{}) + require.Equal(t, "[]", result) +} + +func TestMarshalNotifyEmails_NilSlice(t *testing.T) { + result := MarshalNotifyEmails(nil) + require.Equal(t, "[]", result) +} + +func TestMarshalNotifyEmails_SingleEntry(t *testing.T) { + entries := []NotifyEmailEntry{ + {Email: "test@example.com", Verified: true, Disabled: false}, + } + result := MarshalNotifyEmails(entries) + require.Contains(t, result, `"email":"test@example.com"`) + require.Contains(t, result, `"verified":true`) + require.Contains(t, result, `"disabled":false`) + + // Round-trip: parsing the marshalled result should produce the original entries. + parsed := ParseNotifyEmails(result) + require.Len(t, parsed, 1) + require.Equal(t, entries[0], parsed[0]) +} + +func TestMarshalNotifyEmails_MultipleEntries(t *testing.T) { + entries := []NotifyEmailEntry{ + {Email: "a@example.com", Verified: true, Disabled: false}, + {Email: "b@example.com", Verified: false, Disabled: true}, + } + result := MarshalNotifyEmails(entries) + + // Round-trip verification. + parsed := ParseNotifyEmails(result) + require.Len(t, parsed, 2) + require.Equal(t, entries[0], parsed[0]) + require.Equal(t, entries[1], parsed[1]) +} + +func TestMarshalNotifyEmails_RoundTrip_NewFormat(t *testing.T) { + original := []NotifyEmailEntry{ + {Email: "x@example.com", Verified: true, Disabled: true}, + {Email: "y@example.com", Verified: false, Disabled: false}, + } + marshalled := MarshalNotifyEmails(original) + parsed := ParseNotifyEmails(marshalled) + require.Equal(t, original, parsed) +} + +// ---------- isOldStringArrayFormat (indirectly via ParseNotifyEmails) ---------- + +func TestParseNotifyEmails_MixedOldFormatWithWhitespace(t *testing.T) { + // Emails with leading/trailing whitespace in old format should be trimmed. + raw := `[" alice@example.com "]` + result := ParseNotifyEmails(raw) + require.Len(t, result, 1) + require.Equal(t, "alice@example.com", result[0].Email) +} diff --git a/backend/internal/service/websearch_config_test.go b/backend/internal/service/websearch_config_test.go index 8cd50d0d..c5b96e01 100644 --- a/backend/internal/service/websearch_config_test.go +++ b/backend/internal/service/websearch_config_test.go @@ -1,9 +1,12 @@ +//go:build unit + package service import ( "context" "testing" + "github.com/Wei-Shaw/sub2api/internal/pkg/websearch" "github.com/stretchr/testify/require" ) @@ -141,3 +144,123 @@ func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) { _ = SanitizeWebSearchConfig(context.Background(), cfg) require.Equal(t, "secret", cfg.Providers[0].APIKey) } + +// --- PopulateWebSearchUsage --- + +func TestPopulateWebSearchUsage_NilInput(t *testing.T) { + require.Nil(t, PopulateWebSearchUsage(context.Background(), nil)) +} + +func TestPopulateWebSearchUsage_NoManager_QuotaUsedZero(t *testing.T) { + // Ensure no global manager is set + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Enabled: true, + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(1000)}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.NotNil(t, out) + require.Len(t, out.Providers, 1) + require.Equal(t, int64(0), out.Providers[0].QuotaUsed) +} + +func TestPopulateWebSearchUsage_APIKeyConfigured_True(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key"}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.True(t, out.Providers[0].APIKeyConfigured) +} + +func TestPopulateWebSearchUsage_APIKeyConfigured_False(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: ""}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.False(t, out.Providers[0].APIKeyConfigured) +} + +func TestPopulateWebSearchUsage_NilQuotaLimit(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key", QuotaLimit: nil}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.Nil(t, out.Providers[0].QuotaLimit) +} + +func TestPopulateWebSearchUsage_NonNilQuotaLimit(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(500)}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.NotNil(t, out.Providers[0].QuotaLimit) + require.Equal(t, int64(500), *out.Providers[0].QuotaLimit) +} + +func TestPopulateWebSearchUsage_WithManager_NilRedis(t *testing.T) { + // Manager with nil Redis returns 0 usage without error + mgr := websearch.NewManager([]websearch.ProviderConfig{ + {Type: "brave", APIKey: "k"}, + }, nil) + SetWebSearchManager(mgr) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(1000)}, + }, + } + out := PopulateWebSearchUsage(context.Background(), cfg) + require.Equal(t, int64(0), out.Providers[0].QuotaUsed) + require.True(t, out.Providers[0].APIKeyConfigured) +} + +func TestPopulateWebSearchUsage_DoesNotMutateOriginal(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + cfg := &WebSearchEmulationConfig{ + Providers: []WebSearchProviderConfig{ + {Type: "brave", APIKey: "secret", QuotaLimit: int64Ptr(100)}, + }, + } + _ = PopulateWebSearchUsage(context.Background(), cfg) + // Original should be unchanged + require.Equal(t, "secret", cfg.Providers[0].APIKey) + require.Equal(t, int64(0), cfg.Providers[0].QuotaUsed) +} + +// --- ResetWebSearchUsage --- + +func TestResetWebSearchUsage_NilManager(t *testing.T) { + SetWebSearchManager(nil) + defer SetWebSearchManager(nil) + + err := ResetWebSearchUsage(context.Background(), "brave") + require.Error(t, err) + require.Contains(t, err.Error(), "not initialized") +} From d6965b0676eadba10ba499436302ccb8610af420 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 10:18:39 +0800 Subject: [PATCH 73/88] fix: resolve cherry-pick conflicts and restore compilation - Restore gateway_cache.go to upstream (no lua embeds) - Restore payment_order.go to upstream (use out_trade_no lookup) - Restore payment_fulfillment.go to upstream (same reason) - Add FeaturesConfig field and IsWebSearchEmulationEnabled to Channel - Add applyAccountStatsCost wrapper function - Add SettingKeyWebSearchEmulationConfig constant - Add WebSearchEmulationEnabled to SystemSettings - Add notify code rate limiting methods to EmailCache interface - Remove AllowUserRefund references (ent schema not present) - Fix duplicate import in payment_handler.go - Fix wire_gen.go argument mismatches --- backend/cmd/server/wire_gen.go | 4 +- backend/internal/handler/payment_handler.go | 1 - backend/internal/repository/gateway_cache.go | 257 +----------------- .../internal/service/account_stats_pricing.go | 21 ++ backend/internal/service/channel.go | 16 +- backend/internal/service/domain_constants.go | 3 + backend/internal/service/email_service.go | 4 + .../service/payment_config_providers.go | 21 +- .../service/payment_config_service.go | 2 - .../internal/service/payment_fulfillment.go | 14 +- backend/internal/service/payment_order.go | 254 +---------------- backend/internal/service/settings_view.go | 3 + 12 files changed, 80 insertions(+), 520 deletions(-) diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 69daeecf..a0e84f4c 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -143,7 +143,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache) internal500CounterCache := repository.NewInternal500CounterCache(redisClient) - antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache, accountUsageService) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) @@ -217,8 +217,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { } defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService) paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler) diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index e01a2af1..0425fc49 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -7,7 +7,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index ec4bf40e..58291b66 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -2,42 +2,14 @@ package repository import ( "context" - _ "embed" "fmt" - "strconv" "time" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" ) -const ( - stickySessionPrefix = "sticky_session:" - clientAffinityPrefix = "client_affinity:" - clientAffinityReversePrefix = "client_affinity_rev:" -) - -var ( - //go:embed lua/get_affinity.lua - getAffinityLua string - //go:embed lua/update_affinity.lua - updateAffinityLua string - //go:embed lua/get_affinity_count.lua - getAffinityCountLua string - //go:embed lua/get_affinity_clients.lua - getAffinityClientsLua string - //go:embed lua/get_affinity_clients_with_scores.lua - getAffinityClientsWithScoresLua string - //go:embed lua/clear_account_affinity.lua - clearAccountAffinityLua string - - getAffinityScript = redis.NewScript(getAffinityLua) - updateAffinityScript = redis.NewScript(updateAffinityLua) - getAffinityCountScript = redis.NewScript(getAffinityCountLua) - getAffinityClientsScript = redis.NewScript(getAffinityClientsLua) - getAffinityClientsWithScoresScript = redis.NewScript(getAffinityClientsWithScoresLua) - clearAccountAffinityScript = redis.NewScript(clearAccountAffinityLua) -) +const stickySessionPrefix = "sticky_session:" type gatewayCache struct { rdb *redis.Client @@ -47,16 +19,6 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache { return &gatewayCache{rdb: rdb} } -// ensureScriptLoaded 确保 Lua 脚本已加载到 Redis 服务器的脚本缓存中。 -// Pipeline 中的 Script.Run 只发送 EVALSHA,如果 Redis 重启过导致脚本缓存丢失, -// EVALSHA 会返回 NOSCRIPT 错误。此方法提前加载脚本以避免该问题。 -func ensureScriptLoaded(ctx context.Context, rdb *redis.Client, script *redis.Script) { - exists, err := script.Exists(ctx, rdb).Result() - if err != nil || len(exists) == 0 || !exists[0] { - _ = script.Load(ctx, rdb).Err() - } -} - // buildSessionKey 构建 session key,包含 groupID 实现分组隔离 // 格式: sticky_session:{groupID}:{sessionHash} func buildSessionKey(groupID int64, sessionHash string) string { @@ -79,218 +41,13 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses } // DeleteSessionAccountID 删除粘性会话与账号的绑定关系。 +// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用, +// 以便下次请求能够重新选择可用账号。 +// +// DeleteSessionAccountID removes the sticky session binding for the given session. +// Called when the bound account becomes unavailable (e.g., error status, disabled, +// or unschedulable), allowing subsequent requests to select a new available account. func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { key := buildSessionKey(groupID, sessionHash) return c.rdb.Del(ctx, key).Err() } - -// buildAffinityKey 构建正向亲和 key(client → accounts) -// 格式: client_affinity:{groupID}:{clientID} -func buildAffinityKey(groupID int64, clientID string) string { - return fmt.Sprintf("%s%d:%s", clientAffinityPrefix, groupID, clientID) -} - -// buildAffinityReverseKey 构建反向亲和 key(account → clients) -// 格式: client_affinity_rev:{groupID}:{accountID} -func buildAffinityReverseKey(groupID int64, accountID int64) string { - return fmt.Sprintf("%s%d:%d", clientAffinityReversePrefix, groupID, accountID) -} - -func (c *gatewayCache) GetClientAffinityAccounts(ctx context.Context, groupID int64, clientID string, ttl time.Duration) ([]int64, error) { - key := buildAffinityKey(groupID, clientID) - now := time.Now().Unix() - expireThreshold := now - int64(ttl.Seconds()) - - result, err := getAffinityScript.Run(ctx, c.rdb, []string{key}, expireThreshold).StringSlice() - if err != nil { - if err == redis.Nil { - return nil, nil - } - return nil, err - } - - accountIDs := make([]int64, 0, len(result)) - for _, s := range result { - id, err := strconv.ParseInt(s, 10, 64) - if err != nil { - continue - } - accountIDs = append(accountIDs, id) - } - return accountIDs, nil -} - -func (c *gatewayCache) UpdateClientAffinity(ctx context.Context, groupID int64, clientID string, accountID int64, ttl time.Duration) error { - fwdKey := buildAffinityKey(groupID, clientID) - revKey := buildAffinityReverseKey(groupID, accountID) - now := time.Now().Unix() - ttlSeconds := int64(ttl.Seconds()) - expireThreshold := now - ttlSeconds - - return updateAffinityScript.Run(ctx, c.rdb, []string{fwdKey, revKey}, - now, ttlSeconds, accountID, expireThreshold, clientID, - ).Err() -} - -// GetAccountAffinityCountBatch 批量获取账号的亲和客户端数量(惰性清理过期成员) -func (c *gatewayCache) GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error) { - if len(accountIDs) == 0 { - return map[int64]int64{}, nil - } - - now := time.Now().Unix() - expireThreshold := now - int64(ttl.Seconds()) - - ensureScriptLoaded(ctx, c.rdb, getAffinityCountScript) - - pipe := c.rdb.Pipeline() - cmds := make([]*redis.Cmd, len(accountIDs)) - for i, accID := range accountIDs { - key := buildAffinityReverseKey(groupID, accID) - cmds[i] = getAffinityCountScript.Run(ctx, pipe, []string{key}, expireThreshold) - } - _, err := pipe.Exec(ctx) - if err != nil && err != redis.Nil { - return nil, err - } - - result := make(map[int64]int64, len(accountIDs)) - for i, accID := range accountIDs { - count, _ := cmds[i].Int64() - result[accID] = count - } - return result, nil -} - -// GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和客户端列表(去重)。 -// accountGroups: map[accountID][]groupID,对每个 (groupID, accountID) 组合查询反向索引。 -func (c *gatewayCache) GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error) { - if len(accountGroups) == 0 { - return map[int64][]string{}, nil - } - - now := time.Now().Unix() - expireThreshold := now - int64(ttl.Seconds()) - - // 构建所有 (accountID, groupID) 组合的查询 - type queryItem struct { - accountID int64 - groupID int64 - } - var queries []queryItem - for accID, groupIDs := range accountGroups { - for _, gID := range groupIDs { - queries = append(queries, queryItem{accountID: accID, groupID: gID}) - } - } - - ensureScriptLoaded(ctx, c.rdb, getAffinityClientsScript) - - pipe := c.rdb.Pipeline() - cmds := make([]*redis.Cmd, len(queries)) - for i, q := range queries { - key := buildAffinityReverseKey(q.groupID, q.accountID) - cmds[i] = getAffinityClientsScript.Run(ctx, pipe, []string{key}, expireThreshold) - } - _, err := pipe.Exec(ctx) - if err != nil && err != redis.Nil { - return nil, err - } - - // 合并结果:同一个 accountID 跨多个 group 的 clientID 去重 - result := make(map[int64][]string, len(accountGroups)) - seen := make(map[int64]map[string]struct{}, len(accountGroups)) - for i, q := range queries { - clients, _ := cmds[i].StringSlice() - if len(clients) == 0 { - continue - } - if seen[q.accountID] == nil { - seen[q.accountID] = make(map[string]struct{}) - } - for _, clientID := range clients { - if _, exists := seen[q.accountID][clientID]; !exists { - seen[q.accountID][clientID] = struct{}{} - result[q.accountID] = append(result[q.accountID], clientID) - } - } - } - return result, nil -} - -// GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间戳,去重取最近)。 -func (c *gatewayCache) GetAccountAffinityClientsWithScores( - ctx context.Context, - accountID int64, - groupIDs []int64, - ttl time.Duration, -) ([]service.AffinityClient, error) { - if len(groupIDs) == 0 { - return nil, nil - } - - now := time.Now().Unix() - expireThreshold := now - int64(ttl.Seconds()) - - ensureScriptLoaded(ctx, c.rdb, getAffinityClientsWithScoresScript) - - pipe := c.rdb.Pipeline() - cmds := make([]*redis.Cmd, len(groupIDs)) - for i, gID := range groupIDs { - key := buildAffinityReverseKey(gID, accountID) - cmds[i] = getAffinityClientsWithScoresScript.Run(ctx, pipe, []string{key}, expireThreshold) - } - _, err := pipe.Exec(ctx) - if err != nil && err != redis.Nil { - return nil, err - } - - // 合并跨组结果,同一 clientID 取最近的 lastActive - seen := make(map[string]int64) // clientID → max timestamp - for _, cmd := range cmds { - vals, _ := cmd.StringSlice() - // vals 格式: [clientID1, score1, clientID2, score2, ...] - for j := 0; j+1 < len(vals); j += 2 { - clientID := vals[j] - ts, _ := strconv.ParseInt(vals[j+1], 10, 64) - if existing, ok := seen[clientID]; !ok || ts > existing { - seen[clientID] = ts - } - } - } - - result := make([]service.AffinityClient, 0, len(seen)) - for clientID, ts := range seen { - result = append(result, service.AffinityClient{ - ClientID: clientID, - LastActive: time.Unix(ts, 0), - }) - } - - // 按最后活跃时间降序排序 - service.SortAffinityClients(result) - - return result, nil -} - -// ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引)。 -// 对每个 groupID 执行 Lua 脚本:读取反向索引获取所有客户端, -// 从每个客户端的正向索引中移除该账号,然后删除反向索引。 -func (c *gatewayCache) ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error { - if len(groupIDs) == 0 { - return nil - } - - ensureScriptLoaded(ctx, c.rdb, clearAccountAffinityScript) - - pipe := c.rdb.Pipeline() - for _, gID := range groupIDs { - revKey := buildAffinityReverseKey(gID, accountID) - clearAccountAffinityScript.Run(ctx, pipe, []string{revKey}, gID, accountID) - } - _, err := pipe.Exec(ctx) - if err != nil && err != redis.Nil { - return err - } - return nil -} diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index 61c318d9..47b7496f 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -227,3 +227,24 @@ func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) * } return &cost } + +// applyAccountStatsCost resolves the account stats cost for a usage log entry. +// It resolves the upstream model (falling back to the requested model) and calls +// the 4-level priority chain via resolveAccountStatsCost. +func applyAccountStatsCost( + ctx context.Context, + usageLog *UsageLog, + cs *ChannelService, bs *BillingService, + accountID int64, groupID int64, + upstreamModel, requestedModel string, + tokens UsageTokens, + totalCost float64, +) { + model := upstreamModel + if model == "" { + model = requestedModel + } + usageLog.AccountStatsCost = resolveAccountStatsCost( + ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost, + ) +} diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index b034fda0..b3fb2eac 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -39,7 +39,8 @@ type Channel struct { Status string BillingModelSource string // "requested", "upstream", or "channel_mapped" RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) - Features string // 渠道特性描述(JSON 数组),用于支付页面展示 + Features string // 渠道特性描述(JSON 数组),用于支付页面展示 + FeaturesConfig map[string]any // 渠道功能配置(如 web search emulation) CreatedAt time.Time UpdatedAt time.Time @@ -222,6 +223,19 @@ func (c *Channel) Clone() *Channel { return &cp } +// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。 +func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool { + if c == nil || c.FeaturesConfig == nil { + return false + } + wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any) + if !ok { + return false + } + enabled, ok := wse[platform].(bool) + return ok && enabled +} + // deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution. func deepCopyFeaturesConfig(src map[string]any) map[string]any { dst := make(map[string]any, len(src)) diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index bdced29a..cb452efb 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -258,6 +258,9 @@ const ( // Account Quota Notification SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关 SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组) + + // Web Search Emulation + SettingKeyWebSearchEmulationConfig = "web_search_emulation_config" // JSON 配置 ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 9cfd3bbd..9a03ea30 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -49,6 +49,10 @@ type EmailCache interface { // Returns true if in cooldown period (email was sent recently) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error + + // Notify code rate limiting per user + IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) + GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) } // VerificationCodeData represents verification code data diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 0c71ab29..072ed002 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -30,7 +30,6 @@ type ProviderInstanceResponse struct { Limits string `json:"limits"` Enabled bool `json:"enabled"` RefundEnabled bool `json:"refund_enabled"` - AllowUserRefund bool `json:"allow_user_refund"` SortOrder int `json:"sort_order"` PaymentMode string `json:"payment_mode"` } @@ -47,7 +46,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte resp := ProviderInstanceResponse{ ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name, SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits, - Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, AllowUserRefund: inst.AllowUserRefund, + Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, } resp.Config, err = s.decryptAndMaskConfig(inst.Config) @@ -111,12 +110,10 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C if err != nil { return nil, err } - allowUserRefund := req.AllowUserRefund && req.RefundEnabled return s.entClient.PaymentProviderInstance.Create(). SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc). SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode). SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled). - SetAllowUserRefund(allowUserRefund). Save(ctx) } @@ -224,21 +221,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } if req.RefundEnabled != nil { u.SetRefundEnabled(*req.RefundEnabled) - // Cascade: turning off refund_enabled also disables allow_user_refund - if !*req.RefundEnabled { - u.SetAllowUserRefund(false) - } - } - if req.AllowUserRefund != nil { - // Only allow enabling when refund_enabled is true - if *req.AllowUserRefund { - inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) - if err == nil && inst.RefundEnabled { - u.SetAllowUserRefund(true) - } - } else { - u.SetAllowUserRefund(false) - } } if req.PaymentMode != nil { u.SetPaymentMode(*req.PaymentMode) @@ -250,7 +232,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Context) ([]string, error) { instances, err := s.entClient.PaymentProviderInstance.Query(). Where( - paymentproviderinstance.AllowUserRefundEQ(true), paymentproviderinstance.RefundEnabledEQ(true), ).Select(paymentproviderinstance.FieldID).All(ctx) if err != nil { diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index cce31f4d..6d470342 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -114,7 +114,6 @@ type CreateProviderInstanceRequest struct { SortOrder int `json:"sort_order"` Limits string `json:"limits"` RefundEnabled bool `json:"refund_enabled"` - AllowUserRefund bool `json:"allow_user_refund"` } type UpdateProviderInstanceRequest struct { @@ -126,7 +125,6 @@ type UpdateProviderInstanceRequest struct { SortOrder *int `json:"sort_order"` Limits *string `json:"limits"` RefundEnabled *bool `json:"refund_enabled"` - AllowUserRefund *bool `json:"allow_user_refund"` } type CreatePlanRequest struct { GroupID int64 `json:"group_id"` diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 51307849..de41d742 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -5,6 +5,8 @@ import ( "fmt" "log/slog" "math" + "strconv" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -20,11 +22,17 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme if n.Status != payment.NotificationStatusSuccess { return nil } - oid, err := parseOrderID(n.OrderID) + // Look up order by out_trade_no (the external order ID we sent to the provider) + order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx) if err != nil { - return fmt.Errorf("invalid order ID: %s", n.OrderID) + // Fallback: try legacy format (sub2_N where N is DB ID) + trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix) + if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil { + return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk) + } + return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID) } - return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk) + return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk) } func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error { diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index e81af3f5..ff4dfaa8 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -10,7 +10,6 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment/provider" @@ -170,68 +169,6 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us return nil } -func (s *PaymentService) checkCancelRateLimit(ctx context.Context, userID int64, cfg *PaymentConfig) error { - if !cfg.CancelRateLimitEnabled || cfg.CancelRateLimitMax <= 0 { - return nil - } - windowStart := cancelRateLimitWindowStart(cfg) - operator := fmt.Sprintf("user:%d", userID) - count, err := s.entClient.PaymentAuditLog.Query(). - Where( - paymentauditlog.ActionEQ("ORDER_CANCELLED"), - paymentauditlog.OperatorEQ(operator), - paymentauditlog.CreatedAtGTE(windowStart), - ).Count(ctx) - if err != nil { - slog.Error("check cancel rate limit failed", "userID", userID, "error", err) - return nil // fail open - } - if count >= cfg.CancelRateLimitMax { - return infraerrors.TooManyRequests("CANCEL_RATE_LIMITED", "cancel rate limited"). - WithMetadata(map[string]string{ - "max": strconv.Itoa(cfg.CancelRateLimitMax), - "window": strconv.Itoa(cfg.CancelRateLimitWindow), - "unit": cfg.CancelRateLimitUnit, - }) - } - return nil -} - -func cancelRateLimitWindowStart(cfg *PaymentConfig) time.Time { - now := time.Now() - w := cfg.CancelRateLimitWindow - if w <= 0 { - w = 1 - } - unit := cfg.CancelRateLimitUnit - if unit == "" { - unit = "day" - } - if cfg.CancelRateLimitMode == "fixed" { - switch unit { - case "minute": - t := now.Truncate(time.Minute) - return t.Add(-time.Duration(w-1) * time.Minute) - case "day": - y, m, d := now.Date() - t := time.Date(y, m, d, 0, 0, 0, 0, now.Location()) - return t.AddDate(0, 0, -(w - 1)) - default: // hour - t := now.Truncate(time.Hour) - return t.Add(-time.Duration(w-1) * time.Hour) - } - } - // rolling window - switch unit { - case "minute": - return now.Add(-time.Duration(w) * time.Minute) - case "day": - return now.AddDate(0, 0, -w) - default: // hour - return now.Add(-time.Duration(w) * time.Hour) - } -} - func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error { if limit <= 0 { return nil @@ -252,19 +189,16 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user } func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) { - s.EnsureProviders(ctx) - providerKey := s.registry.GetProviderKey(req.PaymentType) - if providerKey == "" { - return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType)) - } - sel, err := s.loadBalancer.SelectInstance(ctx, providerKey, req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) + // Select an instance across all providers that support the requested payment type. + // This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay"). + sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) if err != nil { - return nil, fmt.Errorf("select provider instance: %w", err) + return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType)) } if sel == nil { return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance") } - prov, err := provider.CreateProvider(providerKey, sel.InstanceID, sel.Config) + prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config) if err != nil { return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable") } @@ -272,7 +206,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen outTradeNo := order.OutTradeNo pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes}) if err != nil { - slog.Error("[PaymentService] CreatePayment failed", "provider", providerKey, "instance", sel.InstanceID, "error", err) + slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err) return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error())) } _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx) @@ -357,6 +291,13 @@ func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p Or if p.PaymentType != "" { q = q.Where(paymentorder.PaymentTypeEQ(p.PaymentType)) } + if p.Keyword != "" { + q = q.Where(paymentorder.Or( + paymentorder.OutTradeNoContainsFold(p.Keyword), + paymentorder.UserEmailContainsFold(p.Keyword), + paymentorder.UserNameContainsFold(p.Keyword), + )) + } total, err := q.Clone().Count(ctx) if err != nil { return nil, 0, fmt.Errorf("count admin orders: %w", err) @@ -368,172 +309,3 @@ func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p Or } return orders, total, nil } - -// --- Cancel & Expire --- - -func (s *PaymentService) CancelOrder(ctx context.Context, orderID, userID int64) (string, error) { - o, err := s.entClient.PaymentOrder.Get(ctx, orderID) - if err != nil { - return "", infraerrors.NotFound("NOT_FOUND", "order not found") - } - if o.UserID != userID { - return "", infraerrors.Forbidden("FORBIDDEN", "no permission for this order") - } - if o.Status != OrderStatusPending { - return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status") - } - return s.cancelCore(ctx, o, OrderStatusCancelled, fmt.Sprintf("user:%d", userID), "user cancelled order") -} - -func (s *PaymentService) AdminCancelOrder(ctx context.Context, orderID int64) (string, error) { - o, err := s.entClient.PaymentOrder.Get(ctx, orderID) - if err != nil { - return "", infraerrors.NotFound("NOT_FOUND", "order not found") - } - if o.Status != OrderStatusPending { - return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status") - } - return s.cancelCore(ctx, o, OrderStatusCancelled, "admin", "admin cancelled order") -} - -func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, fs, op, ad string) (string, error) { - if o.PaymentTradeNo != "" || o.PaymentType != "" { - if s.checkPaid(ctx, o) == "already_paid" { - return "already_paid", nil - } - } - c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusPending)).SetStatus(fs).Save(ctx) - if err != nil { - return "", fmt.Errorf("update order status: %w", err) - } - if c > 0 { - auditAction := "ORDER_CANCELLED" - if fs == OrderStatusExpired { - auditAction = "ORDER_EXPIRED" - } - s.writeAuditLog(ctx, o.ID, auditAction, op, map[string]any{"detail": ad}) - } - return "cancelled", nil -} - -func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string { - prov, err := s.getOrderProvider(ctx, o) - if err != nil { - return "" - } - // Use OutTradeNo as fallback when PaymentTradeNo is empty - // (e.g. EasyPay popup mode where trade_no arrives only via notify callback) - tradeNo := o.PaymentTradeNo - if tradeNo == "" { - tradeNo = o.OutTradeNo - } - resp, err := prov.QueryOrder(ctx, tradeNo) - if err != nil { - slog.Warn("query upstream failed", "orderID", o.ID, "error", err) - return "" - } - if resp.Status == payment.ProviderStatusPaid { - if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil { - slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err) - // Still return already_paid — order was paid, fulfillment can be retried - } - return "already_paid" - } - if cp, ok := prov.(payment.CancelableProvider); ok { - _ = cp.CancelPayment(ctx, tradeNo) - } - return "" -} - -// VerifyOrderByOutTradeNo actively queries the upstream provider to check -// if a payment was made, and processes it if so. This handles the case where -// the provider's notify callback was missed (e.g. EasyPay popup mode). -func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) { - o, err := s.entClient.PaymentOrder.Query(). - Where(paymentorder.OutTradeNo(outTradeNo)). - Only(ctx) - if err != nil { - return nil, infraerrors.NotFound("NOT_FOUND", "order not found") - } - if o.UserID != userID { - return nil, infraerrors.Forbidden("FORBIDDEN", "no permission for this order") - } - // Only verify orders that are still pending or recently expired - if o.Status == OrderStatusPending || o.Status == OrderStatusExpired { - result := s.checkPaid(ctx, o) - if result == "already_paid" { - // Reload order to get updated status - o, err = s.entClient.PaymentOrder.Get(ctx, o.ID) - if err != nil { - return nil, fmt.Errorf("reload order: %w", err) - } - } - } - return o, nil -} - -// VerifyOrderPublic verifies payment status without user authentication. -// Used by the payment result page when the user's session has expired. -func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) { - o, err := s.entClient.PaymentOrder.Query(). - Where(paymentorder.OutTradeNo(outTradeNo)). - Only(ctx) - if err != nil { - return nil, infraerrors.NotFound("NOT_FOUND", "order not found") - } - if o.Status == OrderStatusPending || o.Status == OrderStatusExpired { - result := s.checkPaid(ctx, o) - if result == "already_paid" { - o, err = s.entClient.PaymentOrder.Get(ctx, o.ID) - if err != nil { - return nil, fmt.Errorf("reload order: %w", err) - } - } - } - return o, nil -} - -func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) { - now := time.Now() - orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx) - if err != nil { - return 0, fmt.Errorf("query expired: %w", err) - } - n := 0 - for _, o := range orders { - // Check upstream payment status before expiring — the user may have - // paid just before timeout and the webhook hasn't arrived yet. - outcome, _ := s.cancelCore(ctx, o, OrderStatusExpired, "system", "order expired") - if outcome == "already_paid" { - slog.Info("order was paid during expiry", "orderID", o.ID) - continue - } - if outcome != "" { - n++ - } - } - return n, nil -} - -// getOrderProvider creates a provider using the order's original instance config. -// Falls back to registry lookup if instance ID is missing (legacy orders). -func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { - if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" { - instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) - if err == nil { - cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID) - if err == nil { - providerKey := s.registry.GetProviderKey(o.PaymentType) - if providerKey == "" { - providerKey = o.PaymentType - } - p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg) - if err == nil { - return p, nil - } - } - } - } - s.EnsureProviders(ctx) - return s.registry.GetProvider(o.PaymentType) -} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index ec20fe0a..ab2eb274 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -107,6 +107,9 @@ type SystemSettings struct { EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false) EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) + // Web Search Emulation + WebSearchEmulationEnabled bool // 是否启用 web search 模拟 + // Balance low notification BalanceLowNotifyEnabled bool BalanceLowNotifyThreshold float64 From 24e16b7f599eba35e180ae2e53b5e042775a6421 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 10:58:51 +0800 Subject: [PATCH 74/88] fix: restore resolveOpenAIMessagesDispatchMappedModel and reset VERSION - Restore function deleted during cherry-pick conflict resolution - Reset VERSION to upstream 0.1.112 --- backend/cmd/server/VERSION | 2 +- backend/internal/handler/openai_gateway_handler.go | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 630554d9..4b9b35d8 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.112.4 +0.1.112 diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index dda6d2e3..6c5a6779 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -47,6 +47,13 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode return strings.TrimSpace(apiKey.Group.DefaultMappedModel) } +func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string { + if apiKey == nil || apiKey.Group == nil { + return "" + } + return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel)) +} + // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, From b42f34c359251bd374d957092915a7e79616e0db Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 11:27:32 +0800 Subject: [PATCH 75/88] fix: resolve test compilation errors and restore upstream VERSION - Add missing interface methods to test stubs (RemoveGroupFromUserAllowedGroups, GetNotifyCodeUserRate, IncrNotifyCodeUserRate, UpdateGroupIDByUserAndGroup) - Fix NewUserService call signatures (add 4th param) - Fix GetAccountCount return signature (3 values) - Update api_contract_test.go snapshots for balance_notify fields - Restore resolveOpenAIMessagesDispatchMappedModel function - Reset VERSION to upstream 0.1.112 --- backend/internal/server/api_contract_test.go | 14 ++++++++++++-- .../internal/server/middleware/admin_auth_test.go | 2 +- .../internal/server/middleware/jwt_auth_test.go | 2 +- .../internal/service/admin_service_apikey_test.go | 8 +++++++- .../internal/service/auth_service_register_test.go | 8 ++++++++ backend/internal/service/user_service_test.go | 3 +++ 6 files changed, 32 insertions(+), 5 deletions(-) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 08291faa..44c3f0e4 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -58,6 +58,11 @@ func TestAPIContracts(t *testing.T) { "allowed_groups": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z", + "balance_notify_enabled": false, + "balance_notify_threshold_type": "", + "balance_notify_threshold": null, + "balance_notify_extra_emails": null, + "total_recharged": 0, "run_mode": "standard" } }`, @@ -606,7 +611,12 @@ func TestAPIContracts(t *testing.T) { "payment_cancel_rate_limit_max": 0, "payment_cancel_rate_limit_window": 0, "payment_cancel_rate_limit_unit": "", - "payment_cancel_rate_limit_window_mode": "" + "payment_cancel_rate_limit_window_mode": "", + "balance_low_notify_enabled": false, + "account_quota_notify_enabled": false, + "balance_low_notify_threshold": 0, + "balance_low_notify_recharge_url": "", + "account_quota_notify_emails": [] } }`, }, @@ -699,7 +709,7 @@ func newContractDeps(t *testing.T) *contractDeps { RunMode: config.RunModeStandard, } - userService := service.NewUserService(userRepo, nil, nil) + userService := service.NewUserService(userRepo, nil, nil, nil) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index aafe4a58..ed2578c8 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -39,7 +39,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { return &clone, nil }, } - userService := service.NewUserService(userRepo, nil, nil) + userService := service.NewUserService(userRepo, nil, nil, nil) router := gin.New() router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil))) diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index ad9c1b5b..c483a51e 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -41,7 +41,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer userRepo := &stubJWTUserRepo{users: users} authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) - userSvc := service.NewUserService(userRepo, nil, nil) + userSvc := service.NewUserService(userRepo, nil, nil, nil) mw := NewJWTAuthMiddleware(authSvc, userSvc) r := gin.New() diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 5c18a438..7f0a24da 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -70,6 +70,9 @@ func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *s } func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + panic("unexpected") +} // apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests. type apiKeyRepoStubForGroupUpdate struct { @@ -152,6 +155,9 @@ func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, in func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) { panic("unexpected") } +func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) { + panic("unexpected") +} // groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests. type groupRepoStubForGroupUpdate struct { @@ -194,7 +200,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) { panic("unexpected") } -func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) { +func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) { panic("unexpected") } func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 0999b4f0..103bafe7 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -119,6 +119,14 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai return nil } +func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) { + return 0, nil +} + +func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) { + return 0, nil +} + func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService { cfg := &config.Config{ JWT: config.JWTConfig{ diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 29267c19..a998d5f4 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -49,6 +49,9 @@ func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) er func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} // --- mock: APIKeyAuthCacheInvalidator --- From 4aa0070e3d0ac672e514713eec5b8194ce4160b2 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 11:31:44 +0800 Subject: [PATCH 76/88] fix: Stripe payment type matching in load balancer Checkout page aggregates Stripe sub-types (card,link,alipay,wxpay) under "stripe", but SelectInstance matched against supported_types literally, which doesn't contain "stripe". Now matches by provider_key for Stripe. --- backend/internal/payment/load_balancer.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go index 55cb2043..f0353173 100644 --- a/backend/internal/payment/load_balancer.go +++ b/backend/internal/payment/load_balancer.go @@ -117,7 +117,13 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances( var matched []*dbent.PaymentProviderInstance for _, inst := range instances { - if InstanceSupportsType(inst.SupportedTypes, paymentType) { + // Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay), + // not "stripe" itself. The checkout page aggregates all sub-types under "stripe". + if paymentType == TypeStripe { + if inst.ProviderKey == TypeStripe { + matched = append(matched, inst) + } + } else if InstanceSupportsType(inst.SupportedTypes, paymentType) { matched = append(matched, inst) } } From 6a08efeef9d013d4e5b3551b8c800adca267db63 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 12:11:08 +0800 Subject: [PATCH 77/88] fix: resolve upstream CI failures (lint, test, gofmt) - Fix errcheck: handle Write/Encode return values in brave_test.go - Fix errcheck: defer resp.Body.Close() with _ assignment in tavily.go - Fix gofmt: payment.go, channel.go, payment_config_providers.go - Fix unused: remove dead decodeURLValue in easypay.go - Restore shouldFallbackGeminiModel function (deleted during cherry-pick) - Add missing balanceNotifyService param to NewGatewayService in test - Fix platform default test expectation (empty stays empty) - Fix wildcard pricing test (longest prefix wins, not config order) - Fix subscription group test (SUBSCRIPTION_REPOSITORY_UNAVAILABLE) --- .../handler/admin/channel_handler_test.go | 4 ++-- ...eway_handler_warmup_intercept_unit_test.go | 1 + .../internal/handler/gemini_v1beta_handler.go | 10 ++++++++++ backend/internal/payment/provider/easypay.go | 9 --------- backend/internal/pkg/websearch/brave_test.go | 10 +++++----- backend/internal/pkg/websearch/tavily.go | 2 +- backend/internal/server/routes/payment.go | 1 - .../service/account_stats_pricing_test.go | 4 ++-- .../service/admin_service_apikey_test.go | 4 ++-- backend/internal/service/channel.go | 4 ++-- .../service/payment_config_providers.go | 20 +++++++++---------- 11 files changed, 35 insertions(+), 34 deletions(-) diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go index 2f4b4440..f218cce4 100644 --- a/backend/internal/handler/admin/channel_handler_test.go +++ b/backend/internal/handler/admin/channel_handler_test.go @@ -273,13 +273,13 @@ func TestPricingRequestToService_Defaults(t *testing.T) { wantValue: string(service.BillingModeToken), }, { - name: "empty platform defaults to anthropic", + name: "empty platform stays empty", req: channelModelPricingRequest{ Models: []string{"m1"}, Platform: "", }, wantField: "Platform", - wantValue: "anthropic", + wantValue: "", }, } diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index acea3780..1fdc46ba 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -168,6 +168,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // tlsFPProfileService nil, // channelService nil, // resolver + nil, // balanceNotifyService ) // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 45b5842f..6b8cc482 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -682,6 +682,16 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { return false } +func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool { + if shouldFallbackGeminiModels(res) { + return true + } + if res == nil || res.StatusCode != http.StatusNotFound { + return false + } + return gemini.HasFallbackModel(modelName) +} + // extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。 // 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。 // diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go index c54aba6a..b48a38fe 100644 --- a/backend/internal/payment/provider/easypay.go +++ b/backend/internal/payment/provider/easypay.go @@ -276,12 +276,3 @@ func easyPaySign(params map[string]string, pkey string) string { func easyPayVerifySign(params map[string]string, pkey string, sign string) bool { return hmac.Equal([]byte(easyPaySign(params, pkey)), []byte(sign)) } - -// decodeURLValue URL-decodes a string once. -func decodeURLValue(s string) string { - decoded, err := url.QueryUnescape(s) - if err != nil { - return s - } - return decoded -} diff --git a/backend/internal/pkg/websearch/brave_test.go b/backend/internal/pkg/websearch/brave_test.go index 3fe35020..4dc5b219 100644 --- a/backend/internal/pkg/websearch/brave_test.go +++ b/backend/internal/pkg/websearch/brave_test.go @@ -29,7 +29,7 @@ func TestBraveProvider_Search_Success(t *testing.T) { {URL: "https://tour.go.dev", Title: "Tour", Description: "A Tour of Go", Age: "3 days"}, } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) })) defer srv.Close() @@ -53,7 +53,7 @@ func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { receivedCount = r.URL.Query().Get("count") resp := braveResponse{} - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) })) defer srv.Close() @@ -70,7 +70,7 @@ func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) { func TestBraveProvider_Search_HTTPError(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(429) - w.Write([]byte("rate limited")) + _, _ = w.Write([]byte("rate limited")) })) defer srv.Close() @@ -86,7 +86,7 @@ func TestBraveProvider_Search_HTTPError(t *testing.T) { func TestBraveProvider_Search_InvalidJSON(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("not json")) + _, _ = w.Write([]byte("not json")) })) defer srv.Close() @@ -103,7 +103,7 @@ func TestBraveProvider_Search_InvalidJSON(t *testing.T) { func TestBraveProvider_Search_EmptyResults(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { resp := braveResponse{} - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) })) defer srv.Close() diff --git a/backend/internal/pkg/websearch/tavily.go b/backend/internal/pkg/websearch/tavily.go index 6ac09edf..ac4928a6 100644 --- a/backend/internal/pkg/websearch/tavily.go +++ b/backend/internal/pkg/websearch/tavily.go @@ -60,7 +60,7 @@ func (t *TavilyProvider) Search(ctx context.Context, req SearchRequest) (*Search if err != nil { return nil, fmt.Errorf("tavily: request failed: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize)) if err != nil { diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go index 641c6cd5..72012a4e 100644 --- a/backend/internal/server/routes/payment.go +++ b/backend/internal/server/routes/payment.go @@ -78,7 +78,6 @@ func RegisterPaymentRoutes( adminOrders.POST("/:id/refund", adminPaymentHandler.ProcessRefund) } - // Subscription Plans plans := adminGroup.Group("/plans") { diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go index 36e5eb74..2f625393 100644 --- a/backend/internal/service/account_stats_pricing_test.go +++ b/backend/internal/service/account_stats_pricing_test.go @@ -147,14 +147,14 @@ func TestFindPricingForModel(t *testing.T) { wantNil: true, }, { - name: "wildcard matches by config order (first match wins)", + name: "wildcard matches by longest prefix (most specific wins)", list: []ChannelModelPricing{ {ID: 10, Models: []string{"claude-*"}}, {ID: 11, Models: []string{"claude-opus-*"}}, }, platform: "", model: "claude-opus-4", - wantID: 10, // config order: "claude-*" is first and matches, so it wins + wantID: 11, // "claude-opus-*" is longer prefix, wins over "claude-*" }, { name: "shorter wildcard used when longer does not match", diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 7f0a24da..1e235278 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -412,10 +412,10 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *test userRepo := &userRepoStubForGroupUpdate{} svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo} - // 订阅类型分组应被阻止绑定 + // userSubRepo is nil → SUBSCRIPTION_REPOSITORY_UNAVAILABLE _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) require.Error(t, err) - require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err)) + require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err)) require.False(t, userRepo.addGroupCalled) } diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index b3fb2eac..93beb972 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -37,8 +37,8 @@ type Channel struct { Name string Description string Status string - BillingModelSource string // "requested", "upstream", or "channel_mapped" - RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) + BillingModelSource string // "requested", "upstream", or "channel_mapped" + RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) Features string // 渠道特性描述(JSON 数组),用于支付页面展示 FeaturesConfig map[string]any // 渠道功能配置(如 web search emulation) CreatedAt time.Time diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 072ed002..47008df0 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -22,16 +22,16 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db // ProviderInstanceResponse is the API response for a provider instance. type ProviderInstanceResponse struct { - ID int64 `json:"id"` - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Limits string `json:"limits"` - Enabled bool `json:"enabled"` - RefundEnabled bool `json:"refund_enabled"` - SortOrder int `json:"sort_order"` - PaymentMode string `json:"payment_mode"` + ID int64 `json:"id"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Limits string `json:"limits"` + Enabled bool `json:"enabled"` + RefundEnabled bool `json:"refund_enabled"` + SortOrder int `json:"sort_order"` + PaymentMode string `json:"payment_mode"` } // ListProviderInstancesWithConfig returns provider instances with decrypted config. From e8ee400a3f8720345f025f8ba8e6164d0950592c Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 12:19:44 +0800 Subject: [PATCH 78/88] fix: resolve remaining lint errors for upstream CI - Fix errcheck: brave.go resp.Body.Close, manager_test.go Encode - Fix gofmt: payment_config_service.go - Fix unused: use shouldFallbackGeminiModel (with modelName param) in handler --- .../internal/handler/gemini_v1beta_handler.go | 2 +- backend/internal/pkg/websearch/brave.go | 2 +- .../internal/pkg/websearch/manager_test.go | 4 +-- .../service/payment_config_service.go | 34 +++++++++---------- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 6b8cc482..d200c17c 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { googleError(c, http.StatusBadGateway, err.Error()) return } - if shouldFallbackGeminiModels(res) { + if shouldFallbackGeminiModel(modelName, res) { c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) return } diff --git a/backend/internal/pkg/websearch/brave.go b/backend/internal/pkg/websearch/brave.go index 5620ca8d..707e7029 100644 --- a/backend/internal/pkg/websearch/brave.go +++ b/backend/internal/pkg/websearch/brave.go @@ -62,7 +62,7 @@ func (b *BraveProvider) Search(ctx context.Context, req SearchRequest) (*SearchR if err != nil { return nil, fmt.Errorf("brave: request failed: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize)) if err != nil { diff --git a/backend/internal/pkg/websearch/manager_test.go b/backend/internal/pkg/websearch/manager_test.go index cbcf1b76..a4413417 100644 --- a/backend/internal/pkg/websearch/manager_test.go +++ b/backend/internal/pkg/websearch/manager_test.go @@ -50,7 +50,7 @@ func TestManager_SearchWithBestProvider_UsesFirstAvailable(t *testing.T) { srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { resp := braveResponse{} resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}} - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) })) defer srvBrave.Close() @@ -77,7 +77,7 @@ func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { resp := braveResponse{} resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}} - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) })) defer srv.Close() diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 6d470342..9042c3ab 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -105,26 +105,26 @@ type MethodLimitsResponse struct { } type CreateProviderInstanceRequest struct { - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled bool `json:"enabled"` - PaymentMode string `json:"payment_mode"` - SortOrder int `json:"sort_order"` - Limits string `json:"limits"` - RefundEnabled bool `json:"refund_enabled"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled bool `json:"enabled"` + PaymentMode string `json:"payment_mode"` + SortOrder int `json:"sort_order"` + Limits string `json:"limits"` + RefundEnabled bool `json:"refund_enabled"` } type UpdateProviderInstanceRequest struct { - Name *string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled *bool `json:"enabled"` - PaymentMode *string `json:"payment_mode"` - SortOrder *int `json:"sort_order"` - Limits *string `json:"limits"` - RefundEnabled *bool `json:"refund_enabled"` + Name *string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled *bool `json:"enabled"` + PaymentMode *string `json:"payment_mode"` + SortOrder *int `json:"sort_order"` + Limits *string `json:"limits"` + RefundEnabled *bool `json:"refund_enabled"` } type CreatePlanRequest struct { GroupID int64 `json:"group_id"` From f1297a3694973d7ba9274d7b5e5874aafa1dfa56 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 16:26:46 +0800 Subject: [PATCH 79/88] feat: add per-provider allow_user_refund control and align wildcard matching allow_user_refund: - Add allow_user_refund field to PaymentProviderInstance ent schema - Migration 103: ALTER TABLE payment_provider_instances ADD COLUMN - Cascade logic: disabling refund_enabled auto-disables allow_user_refund - User refund validation: check provider instance allows user refund - Admin refund validation: check provider instance allows admin refund - Subscription refund: deduct days on refund, rollback on failure - New endpoint: GET /payment/orders/refund-eligible-providers - Frontend: ToggleSwitch in ProviderCard/Dialog, cascade in SettingsView Wildcard matching: - Change findPricingForModel from "longest prefix wins" to "config order priority (first match wins)", aligning with channel service behavior --- backend/ent/client.go | 36 +++---- backend/ent/intercept/intercept.go | 1 - backend/ent/migrate/schema.go | 1 + backend/ent/mutation.go | 94 +++++++++++++++---- backend/ent/paymentproviderinstance.go | 13 ++- .../paymentproviderinstance.go | 10 ++ backend/ent/paymentproviderinstance/where.go | 15 +++ backend/ent/paymentproviderinstance_create.go | 65 +++++++++++++ backend/ent/paymentproviderinstance_update.go | 34 +++++++ backend/ent/predicate/predicate.go | 1 - backend/ent/runtime/runtime.go | 8 +- .../ent/schema/payment_provider_instance.go | 2 + backend/internal/handler/payment_handler.go | 10 ++ backend/internal/server/routes/payment.go | 1 + .../internal/service/account_stats_pricing.go | 22 +---- .../service/account_stats_pricing_test.go | 4 +- .../service/payment_config_providers.go | 42 ++++++--- .../service/payment_config_service.go | 36 +++---- backend/internal/service/payment_refund.go | 64 +++++++++++++ .../migrations/103_add_allow_user_refund.sql | 1 + frontend/src/api/payment.ts | 5 + .../payment/PaymentProviderDialog.vue | 8 +- .../payment/PaymentProviderList.vue | 2 +- .../src/components/payment/ProviderCard.vue | 3 +- frontend/src/i18n/locales/en.ts | 1 + frontend/src/i18n/locales/zh.ts | 1 + frontend/src/types/payment.ts | 2 + frontend/src/views/admin/SettingsView.vue | 21 ++++- 28 files changed, 405 insertions(+), 98 deletions(-) create mode 100644 backend/migrations/103_add_allow_user_refund.sql diff --git a/backend/ent/client.go b/backend/ent/client.go index 3da7acf8..e52e015a 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -333,10 +333,10 @@ func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, - c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, - c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, - c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, + c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, + c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) @@ -349,10 +349,10 @@ func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, - c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, - c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, - c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, + c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, + c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) @@ -4629,19 +4629,19 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription type ( hooks struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, - PaymentOrder, PaymentProviderInstance, PromoCode, - PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, - TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook + ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, + SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, + UserAttributeValue, UserSubscription []ent.Hook } inters struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, - PaymentOrder, PaymentProviderInstance, PromoCode, - PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, - TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor + ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, + SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, + UserAttributeValue, UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 77d3e16e..8d8320bb 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -336,7 +336,6 @@ func (f TraversePaymentAuditLog) Traverse(ctx context.Context, q ent.Query) erro return fmt.Errorf("unexpected query type %T. expect *ent.PaymentAuditLogQuery", q) } - // The PaymentOrderFunc type is an adapter to allow the use of ordinary function as a Querier. type PaymentOrderFunc func(context.Context, *ent.PaymentOrderQuery) (ent.Value, error) diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 1fff61ba..68bdbf55 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -616,6 +616,7 @@ var ( {Name: "sort_order", Type: field.TypeInt, Default: 0}, {Name: "limits", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, {Name: "refund_enabled", Type: field.TypeBool, Default: false}, + {Name: "allow_user_refund", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 3bca248d..524ccb92 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -15642,25 +15642,26 @@ func (m *PaymentOrderMutation) ResetEdge(name string) error { // PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph. type PaymentProviderInstanceMutation struct { config - op Op - typ string - id *int64 - provider_key *string - name *string - _config *string - supported_types *string - enabled *bool - payment_mode *string - sort_order *int - addsort_order *int - limits *string - refund_enabled *bool - created_at *time.Time - updated_at *time.Time - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*PaymentProviderInstance, error) - predicates []predicate.PaymentProviderInstance + op Op + typ string + id *int64 + provider_key *string + name *string + _config *string + supported_types *string + enabled *bool + payment_mode *string + sort_order *int + addsort_order *int + limits *string + refund_enabled *bool + allow_user_refund *bool + created_at *time.Time + updated_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*PaymentProviderInstance, error) + predicates []predicate.PaymentProviderInstance } var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil) @@ -16105,6 +16106,42 @@ func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() { m.refund_enabled = nil } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) { + m.allow_user_refund = &b +} + +// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation. +func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) { + v := m.allow_user_refund + if v == nil { + return + } + return *v, true +} + +// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAllowUserRefund requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err) + } + return oldValue.AllowUserRefund, nil +} + +// ResetAllowUserRefund resets all changes to the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() { + m.allow_user_refund = nil +} + // SetCreatedAt sets the "created_at" field. func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -16211,7 +16248,7 @@ func (m *PaymentProviderInstanceMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *PaymentProviderInstanceMutation) Fields() []string { - fields := make([]string, 0, 11) + fields := make([]string, 0, 12) if m.provider_key != nil { fields = append(fields, paymentproviderinstance.FieldProviderKey) } @@ -16239,6 +16276,9 @@ func (m *PaymentProviderInstanceMutation) Fields() []string { if m.refund_enabled != nil { fields = append(fields, paymentproviderinstance.FieldRefundEnabled) } + if m.allow_user_refund != nil { + fields = append(fields, paymentproviderinstance.FieldAllowUserRefund) + } if m.created_at != nil { fields = append(fields, paymentproviderinstance.FieldCreatedAt) } @@ -16271,6 +16311,8 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { return m.Limits() case paymentproviderinstance.FieldRefundEnabled: return m.RefundEnabled() + case paymentproviderinstance.FieldAllowUserRefund: + return m.AllowUserRefund() case paymentproviderinstance.FieldCreatedAt: return m.CreatedAt() case paymentproviderinstance.FieldUpdatedAt: @@ -16302,6 +16344,8 @@ func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name str return m.OldLimits(ctx) case paymentproviderinstance.FieldRefundEnabled: return m.OldRefundEnabled(ctx) + case paymentproviderinstance.FieldAllowUserRefund: + return m.OldAllowUserRefund(ctx) case paymentproviderinstance.FieldCreatedAt: return m.OldCreatedAt(ctx) case paymentproviderinstance.FieldUpdatedAt: @@ -16378,6 +16422,13 @@ func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) } m.SetRefundEnabled(v) return nil + case paymentproviderinstance.FieldAllowUserRefund: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowUserRefund(v) + return nil case paymentproviderinstance.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -16483,6 +16534,9 @@ func (m *PaymentProviderInstanceMutation) ResetField(name string) error { case paymentproviderinstance.FieldRefundEnabled: m.ResetRefundEnabled() return nil + case paymentproviderinstance.FieldAllowUserRefund: + m.ResetAllowUserRefund() + return nil case paymentproviderinstance.FieldCreatedAt: m.ResetCreatedAt() return nil diff --git a/backend/ent/paymentproviderinstance.go b/backend/ent/paymentproviderinstance.go index 087cb13a..4279b86e 100644 --- a/backend/ent/paymentproviderinstance.go +++ b/backend/ent/paymentproviderinstance.go @@ -35,6 +35,8 @@ type PaymentProviderInstance struct { Limits string `json:"limits,omitempty"` // RefundEnabled holds the value of the "refund_enabled" field. RefundEnabled bool `json:"refund_enabled,omitempty"` + // AllowUserRefund holds the value of the "allow_user_refund" field. + AllowUserRefund bool `json:"allow_user_refund,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. @@ -47,7 +49,7 @@ func (*PaymentProviderInstance) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled: + case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled, paymentproviderinstance.FieldAllowUserRefund: values[i] = new(sql.NullBool) case paymentproviderinstance.FieldID, paymentproviderinstance.FieldSortOrder: values[i] = new(sql.NullInt64) @@ -130,6 +132,12 @@ func (_m *PaymentProviderInstance) assignValues(columns []string, values []any) } else if value.Valid { _m.RefundEnabled = value.Bool } + case paymentproviderinstance.FieldAllowUserRefund: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field allow_user_refund", values[i]) + } else if value.Valid { + _m.AllowUserRefund = value.Bool + } case paymentproviderinstance.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -205,6 +213,9 @@ func (_m *PaymentProviderInstance) String() string { builder.WriteString("refund_enabled=") builder.WriteString(fmt.Sprintf("%v", _m.RefundEnabled)) builder.WriteString(", ") + builder.WriteString("allow_user_refund=") + builder.WriteString(fmt.Sprintf("%v", _m.AllowUserRefund)) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") diff --git a/backend/ent/paymentproviderinstance/paymentproviderinstance.go b/backend/ent/paymentproviderinstance/paymentproviderinstance.go index c430fef6..eb1b0c52 100644 --- a/backend/ent/paymentproviderinstance/paymentproviderinstance.go +++ b/backend/ent/paymentproviderinstance/paymentproviderinstance.go @@ -31,6 +31,8 @@ const ( FieldLimits = "limits" // FieldRefundEnabled holds the string denoting the refund_enabled field in the database. FieldRefundEnabled = "refund_enabled" + // FieldAllowUserRefund holds the string denoting the allow_user_refund field in the database. + FieldAllowUserRefund = "allow_user_refund" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // FieldUpdatedAt holds the string denoting the updated_at field in the database. @@ -51,6 +53,7 @@ var Columns = []string{ FieldSortOrder, FieldLimits, FieldRefundEnabled, + FieldAllowUserRefund, FieldCreatedAt, FieldUpdatedAt, } @@ -88,6 +91,8 @@ var ( DefaultLimits string // DefaultRefundEnabled holds the default value on creation for the "refund_enabled" field. DefaultRefundEnabled bool + // DefaultAllowUserRefund holds the default value on creation for the "allow_user_refund" field. + DefaultAllowUserRefund bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. @@ -149,6 +154,11 @@ func ByRefundEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldRefundEnabled, opts...).ToFunc() } +// ByAllowUserRefund orders the results by the allow_user_refund field. +func ByAllowUserRefund(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAllowUserRefund, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/ent/paymentproviderinstance/where.go b/backend/ent/paymentproviderinstance/where.go index 7b99517f..40e5a1f6 100644 --- a/backend/ent/paymentproviderinstance/where.go +++ b/backend/ent/paymentproviderinstance/where.go @@ -99,6 +99,11 @@ func RefundEnabled(v bool) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v)) } +// AllowUserRefund applies equality check predicate on the "allow_user_refund" field. It's identical to AllowUserRefundEQ. +func AllowUserRefund(v bool) predicate.PaymentProviderInstance { + return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v)) @@ -559,6 +564,16 @@ func RefundEnabledNEQ(v bool) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldRefundEnabled, v)) } +// AllowUserRefundEQ applies the EQ predicate on the "allow_user_refund" field. +func AllowUserRefundEQ(v bool) predicate.PaymentProviderInstance { + return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v)) +} + +// AllowUserRefundNEQ applies the NEQ predicate on the "allow_user_refund" field. +func AllowUserRefundNEQ(v bool) predicate.PaymentProviderInstance { + return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldAllowUserRefund, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/ent/paymentproviderinstance_create.go b/backend/ent/paymentproviderinstance_create.go index 20b16ddd..d1b14617 100644 --- a/backend/ent/paymentproviderinstance_create.go +++ b/backend/ent/paymentproviderinstance_create.go @@ -132,6 +132,20 @@ func (_c *PaymentProviderInstanceCreate) SetNillableRefundEnabled(v *bool) *Paym return _c } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (_c *PaymentProviderInstanceCreate) SetAllowUserRefund(v bool) *PaymentProviderInstanceCreate { + _c.mutation.SetAllowUserRefund(v) + return _c +} + +// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil. +func (_c *PaymentProviderInstanceCreate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceCreate { + if v != nil { + _c.SetAllowUserRefund(*v) + } + return _c +} + // SetCreatedAt sets the "created_at" field. func (_c *PaymentProviderInstanceCreate) SetCreatedAt(v time.Time) *PaymentProviderInstanceCreate { _c.mutation.SetCreatedAt(v) @@ -223,6 +237,10 @@ func (_c *PaymentProviderInstanceCreate) defaults() { v := paymentproviderinstance.DefaultRefundEnabled _c.mutation.SetRefundEnabled(v) } + if _, ok := _c.mutation.AllowUserRefund(); !ok { + v := paymentproviderinstance.DefaultAllowUserRefund + _c.mutation.SetAllowUserRefund(v) + } if _, ok := _c.mutation.CreatedAt(); !ok { v := paymentproviderinstance.DefaultCreatedAt() _c.mutation.SetCreatedAt(v) @@ -282,6 +300,9 @@ func (_c *PaymentProviderInstanceCreate) check() error { if _, ok := _c.mutation.RefundEnabled(); !ok { return &ValidationError{Name: "refund_enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.refund_enabled"`)} } + if _, ok := _c.mutation.AllowUserRefund(); !ok { + return &ValidationError{Name: "allow_user_refund", err: errors.New(`ent: missing required field "PaymentProviderInstance.allow_user_refund"`)} + } if _, ok := _c.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.created_at"`)} } @@ -351,6 +372,10 @@ func (_c *PaymentProviderInstanceCreate) createSpec() (*PaymentProviderInstance, _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) _node.RefundEnabled = value } + if value, ok := _c.mutation.AllowUserRefund(); ok { + _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value) + _node.AllowUserRefund = value + } if value, ok := _c.mutation.CreatedAt(); ok { _spec.SetField(paymentproviderinstance.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -525,6 +550,18 @@ func (u *PaymentProviderInstanceUpsert) UpdateRefundEnabled() *PaymentProviderIn return u } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (u *PaymentProviderInstanceUpsert) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsert { + u.Set(paymentproviderinstance.FieldAllowUserRefund, v) + return u +} + +// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create. +func (u *PaymentProviderInstanceUpsert) UpdateAllowUserRefund() *PaymentProviderInstanceUpsert { + u.SetExcluded(paymentproviderinstance.FieldAllowUserRefund) + return u +} + // SetUpdatedAt sets the "updated_at" field. func (u *PaymentProviderInstanceUpsert) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsert { u.Set(paymentproviderinstance.FieldUpdatedAt, v) @@ -715,6 +752,20 @@ func (u *PaymentProviderInstanceUpsertOne) UpdateRefundEnabled() *PaymentProvide }) } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (u *PaymentProviderInstanceUpsertOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertOne { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.SetAllowUserRefund(v) + }) +} + +// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create. +func (u *PaymentProviderInstanceUpsertOne) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertOne { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.UpdateAllowUserRefund() + }) +} + // SetUpdatedAt sets the "updated_at" field. func (u *PaymentProviderInstanceUpsertOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertOne { return u.Update(func(s *PaymentProviderInstanceUpsert) { @@ -1073,6 +1124,20 @@ func (u *PaymentProviderInstanceUpsertBulk) UpdateRefundEnabled() *PaymentProvid }) } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (u *PaymentProviderInstanceUpsertBulk) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertBulk { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.SetAllowUserRefund(v) + }) +} + +// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create. +func (u *PaymentProviderInstanceUpsertBulk) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertBulk { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.UpdateAllowUserRefund() + }) +} + // SetUpdatedAt sets the "updated_at" field. func (u *PaymentProviderInstanceUpsertBulk) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertBulk { return u.Update(func(s *PaymentProviderInstanceUpsert) { diff --git a/backend/ent/paymentproviderinstance_update.go b/backend/ent/paymentproviderinstance_update.go index 06dba527..6bb3a82d 100644 --- a/backend/ent/paymentproviderinstance_update.go +++ b/backend/ent/paymentproviderinstance_update.go @@ -161,6 +161,20 @@ func (_u *PaymentProviderInstanceUpdate) SetNillableRefundEnabled(v *bool) *Paym return _u } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (_u *PaymentProviderInstanceUpdate) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdate { + _u.mutation.SetAllowUserRefund(v) + return _u +} + +// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil. +func (_u *PaymentProviderInstanceUpdate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdate { + if v != nil { + _u.SetAllowUserRefund(*v) + } + return _u +} + // SetUpdatedAt sets the "updated_at" field. func (_u *PaymentProviderInstanceUpdate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdate { _u.mutation.SetUpdatedAt(v) @@ -275,6 +289,9 @@ func (_u *PaymentProviderInstanceUpdate) sqlSave(ctx context.Context) (_node int if value, ok := _u.mutation.RefundEnabled(); ok { _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.AllowUserRefund(); ok { + _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value) + } if value, ok := _u.mutation.UpdatedAt(); ok { _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value) } @@ -431,6 +448,20 @@ func (_u *PaymentProviderInstanceUpdateOne) SetNillableRefundEnabled(v *bool) *P return _u } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (_u *PaymentProviderInstanceUpdateOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdateOne { + _u.mutation.SetAllowUserRefund(v) + return _u +} + +// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil. +func (_u *PaymentProviderInstanceUpdateOne) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdateOne { + if v != nil { + _u.SetAllowUserRefund(*v) + } + return _u +} + // SetUpdatedAt sets the "updated_at" field. func (_u *PaymentProviderInstanceUpdateOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdateOne { _u.mutation.SetUpdatedAt(v) @@ -575,6 +606,9 @@ func (_u *PaymentProviderInstanceUpdateOne) sqlSave(ctx context.Context) (_node if value, ok := _u.mutation.RefundEnabled(); ok { _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.AllowUserRefund(); ok { + _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value) + } if value, ok := _u.mutation.UpdatedAt(); ok { _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value) } diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index 67f37c75..ef551940 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -33,7 +33,6 @@ type IdempotencyRecord func(*sql.Selector) // PaymentAuditLog is the predicate function for paymentauditlog builders. type PaymentAuditLog func(*sql.Selector) - // PaymentOrder is the predicate function for paymentorder builders. type PaymentOrder func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 951b5f99..fbdd08c7 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -668,12 +668,16 @@ func init() { paymentproviderinstanceDescRefundEnabled := paymentproviderinstanceFields[8].Descriptor() // paymentproviderinstance.DefaultRefundEnabled holds the default value on creation for the refund_enabled field. paymentproviderinstance.DefaultRefundEnabled = paymentproviderinstanceDescRefundEnabled.Default.(bool) + // paymentproviderinstanceDescAllowUserRefund is the schema descriptor for allow_user_refund field. + paymentproviderinstanceDescAllowUserRefund := paymentproviderinstanceFields[9].Descriptor() + // paymentproviderinstance.DefaultAllowUserRefund holds the default value on creation for the allow_user_refund field. + paymentproviderinstance.DefaultAllowUserRefund = paymentproviderinstanceDescAllowUserRefund.Default.(bool) // paymentproviderinstanceDescCreatedAt is the schema descriptor for created_at field. - paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[9].Descriptor() + paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[10].Descriptor() // paymentproviderinstance.DefaultCreatedAt holds the default value on creation for the created_at field. paymentproviderinstance.DefaultCreatedAt = paymentproviderinstanceDescCreatedAt.Default.(func() time.Time) // paymentproviderinstanceDescUpdatedAt is the schema descriptor for updated_at field. - paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[10].Descriptor() + paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[11].Descriptor() // paymentproviderinstance.DefaultUpdatedAt holds the default value on creation for the updated_at field. paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time) // paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. diff --git a/backend/ent/schema/payment_provider_instance.go b/backend/ent/schema/payment_provider_instance.go index 08ab7d31..e4c0b72c 100644 --- a/backend/ent/schema/payment_provider_instance.go +++ b/backend/ent/schema/payment_provider_instance.go @@ -53,6 +53,8 @@ func (PaymentProviderInstance) Fields() []ent.Field { Default(""), field.Bool("refund_enabled"). Default(false), + field.Bool("allow_user_refund"). + Default(false), field.Time("created_at"). Immutable(). Default(time.Now). diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index 0425fc49..5fde86fa 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -335,6 +335,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) { response.Success(c, gin.H{"message": "refund requested"}) } +// GetRefundEligibleProviders returns provider instance IDs that allow user refund. +func (h *PaymentHandler) GetRefundEligibleProviders(c *gin.Context) { + ids, err := h.configService.GetUserRefundEligibleInstanceIDs(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"provider_instance_ids": ids}) +} + // VerifyOrderRequest is the request body for verifying a payment order. type VerifyOrderRequest struct { OutTradeNo string `json:"out_trade_no" binding:"required"` diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go index 72012a4e..8def7559 100644 --- a/backend/internal/server/routes/payment.go +++ b/backend/internal/server/routes/payment.go @@ -37,6 +37,7 @@ func RegisterPaymentRoutes( orders.GET("/:id", paymentHandler.GetOrder) orders.POST("/:id/cancel", paymentHandler.CancelOrder) orders.POST("/:id/refund-request", paymentHandler.RequestRefund) + orders.GET("/refund-eligible-providers", paymentHandler.GetRefundEligibleProviders) } } diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index 47b7496f..90ff450f 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -2,7 +2,6 @@ package service import ( "context" - "sort" "strings" ) @@ -116,14 +115,8 @@ func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int return false } -// wildcardMatch 通配符匹配候选项(用于排序) -type wildcardMatch struct { - prefixLen int - pricing *ChannelModelPricing -} - // findPricingForModel 在定价列表中查找匹配的模型定价。 -// 先精确匹配,再通配符匹配(前缀越长优先级越高)。 +// 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。 func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing { // 精确匹配优先 for i := range pricingList { @@ -137,8 +130,7 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower } } } - // 通配符匹配:收集所有匹配项,按前缀长度降序取最长 - var matches []wildcardMatch + // 通配符匹配:按配置顺序,先匹配先使用 for i := range pricingList { p := &pricingList[i] if !isPlatformMatch(platform, p.Platform) { @@ -151,17 +143,11 @@ func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower } prefix := strings.TrimSuffix(ml, "*") if strings.HasPrefix(modelLower, prefix) { - matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p}) + return p } } } - if len(matches) == 0 { - return nil - } - sort.Slice(matches, func(i, j int) bool { - return matches[i].prefixLen > matches[j].prefixLen - }) - return matches[0].pricing + return nil } // isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。 diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go index 2f625393..36e5eb74 100644 --- a/backend/internal/service/account_stats_pricing_test.go +++ b/backend/internal/service/account_stats_pricing_test.go @@ -147,14 +147,14 @@ func TestFindPricingForModel(t *testing.T) { wantNil: true, }, { - name: "wildcard matches by longest prefix (most specific wins)", + name: "wildcard matches by config order (first match wins)", list: []ChannelModelPricing{ {ID: 10, Models: []string{"claude-*"}}, {ID: 11, Models: []string{"claude-opus-*"}}, }, platform: "", model: "claude-opus-4", - wantID: 11, // "claude-opus-*" is longer prefix, wins over "claude-*" + wantID: 10, // config order: "claude-*" is first and matches, so it wins }, { name: "shorter wildcard used when longer does not match", diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 47008df0..0f7cb99a 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -22,16 +22,17 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db // ProviderInstanceResponse is the API response for a provider instance. type ProviderInstanceResponse struct { - ID int64 `json:"id"` - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Limits string `json:"limits"` - Enabled bool `json:"enabled"` - RefundEnabled bool `json:"refund_enabled"` - SortOrder int `json:"sort_order"` - PaymentMode string `json:"payment_mode"` + ID int64 `json:"id"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Limits string `json:"limits"` + Enabled bool `json:"enabled"` + RefundEnabled bool `json:"refund_enabled"` + AllowUserRefund bool `json:"allow_user_refund"` + SortOrder int `json:"sort_order"` + PaymentMode string `json:"payment_mode"` } // ListProviderInstancesWithConfig returns provider instances with decrypted config. @@ -47,7 +48,8 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name, SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits, Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, - SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, + AllowUserRefund: inst.AllowUserRefund, + SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, } resp.Config, err = s.decryptAndMaskConfig(inst.Config) if err != nil { @@ -110,10 +112,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C if err != nil { return nil, err } + allowUserRefund := req.AllowUserRefund && req.RefundEnabled return s.entClient.PaymentProviderInstance.Create(). SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc). SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode). SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled). + SetAllowUserRefund(allowUserRefund). Save(ctx) } @@ -221,6 +225,21 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } if req.RefundEnabled != nil { u.SetRefundEnabled(*req.RefundEnabled) + // Cascade: turning off refund_enabled also disables allow_user_refund + if !*req.RefundEnabled { + u.SetAllowUserRefund(false) + } + } + if req.AllowUserRefund != nil { + // Only allow enabling when refund_enabled is true + if *req.AllowUserRefund { + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) + if err == nil && inst.RefundEnabled { + u.SetAllowUserRefund(true) + } + } else { + u.SetAllowUserRefund(false) + } } if req.PaymentMode != nil { u.SetPaymentMode(*req.PaymentMode) @@ -233,6 +252,7 @@ func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Cont instances, err := s.entClient.PaymentProviderInstance.Query(). Where( paymentproviderinstance.RefundEnabledEQ(true), + paymentproviderinstance.AllowUserRefundEQ(true), ).Select(paymentproviderinstance.FieldID).All(ctx) if err != nil { return nil, err diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 9042c3ab..cce31f4d 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -105,26 +105,28 @@ type MethodLimitsResponse struct { } type CreateProviderInstanceRequest struct { - ProviderKey string `json:"provider_key"` - Name string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled bool `json:"enabled"` - PaymentMode string `json:"payment_mode"` - SortOrder int `json:"sort_order"` - Limits string `json:"limits"` - RefundEnabled bool `json:"refund_enabled"` + ProviderKey string `json:"provider_key"` + Name string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled bool `json:"enabled"` + PaymentMode string `json:"payment_mode"` + SortOrder int `json:"sort_order"` + Limits string `json:"limits"` + RefundEnabled bool `json:"refund_enabled"` + AllowUserRefund bool `json:"allow_user_refund"` } type UpdateProviderInstanceRequest struct { - Name *string `json:"name"` - Config map[string]string `json:"config"` - SupportedTypes []string `json:"supported_types"` - Enabled *bool `json:"enabled"` - PaymentMode *string `json:"payment_mode"` - SortOrder *int `json:"sort_order"` - Limits *string `json:"limits"` - RefundEnabled *bool `json:"refund_enabled"` + Name *string `json:"name"` + Config map[string]string `json:"config"` + SupportedTypes []string `json:"supported_types"` + Enabled *bool `json:"enabled"` + PaymentMode *string `json:"payment_mode"` + SortOrder *int `json:"sort_order"` + Limits *string `json:"limits"` + RefundEnabled *bool `json:"refund_enabled"` + AllowUserRefund *bool `json:"allow_user_refund"` } type CreatePlanRequest struct { GroupID int64 `json:"group_id"` diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index 68f9c697..75d75b2f 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -17,6 +17,19 @@ import ( // --- Refund Flow --- +// getOrderProviderInstance looks up the provider instance that processed this order. +// Returns nil, nil for legacy orders without provider_instance_id. +func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { + if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" { + return nil, nil + } + instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) + if err != nil { + return nil, nil + } + return s.entClient.PaymentProviderInstance.Get(ctx, instID) +} + func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error { o, err := s.validateRefundRequest(ctx, oid, uid) if err != nil { @@ -57,6 +70,14 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int if o.Status != OrderStatusCompleted { return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund") } + // Check provider instance allows user refund + inst, err := s.getOrderProviderInstance(ctx, o) + if err != nil || inst == nil { + return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order") + } + if !inst.AllowUserRefund { + return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "user refund is not enabled for this provider") + } return o, nil } @@ -69,6 +90,18 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float if !psSliceContains(ok, o.Status) { return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund") } + // Check provider instance allows admin refund + inst, instErr := s.getOrderProviderInstance(ctx, o) + if instErr != nil { + slog.Warn("refund: provider instance not found", "orderID", oid, "error", instErr) + } + if inst != nil && !inst.RefundEnabled { + return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not enabled for this provider") + } + if inst == nil && instErr == nil { + // Legacy order without provider_instance_id — block refund + return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not available for this order") + } if math.IsNaN(amt) || math.IsInf(amt, 0) { return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount") } @@ -102,6 +135,15 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float func (s *PaymentService) prepDeduct(ctx context.Context, o *dbent.PaymentOrder, p *RefundPlan, force bool) *RefundResult { if o.OrderType == payment.OrderTypeSubscription { p.DeductionType = payment.DeductionTypeSubscription + if o.SubscriptionGroupID != nil && o.SubscriptionDays != nil { + p.SubDaysToDeduct = *o.SubscriptionDays + sub, err := s.subscriptionSvc.GetActiveSubscription(ctx, o.UserID, *o.SubscriptionGroupID) + if err == nil && sub != nil { + p.SubscriptionID = sub.ID + } else if !force { + return &RefundResult{Success: false, Warning: "cannot find active subscription for deduction, use force", RequireForce: true} + } + } return nil } u, err := s.userRepo.GetByID(ctx, o.UserID) @@ -137,6 +179,21 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref p.BalanceToDeduct = 0 } } + if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 { + if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") { + _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct) + if err != nil { + slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct) + if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil { + s.restoreStatus(ctx, p) + return nil, fmt.Errorf("revoke subscription: %w", revokeErr) + } + } + } else { + slog.Warn("skipping subscription deduction on retry (previous rollback failed)", "orderID", p.OrderID) + p.SubDaysToDeduct = 0 + } + } if err := s.gwRefund(ctx, p); err != nil { return s.handleGwFail(ctx, p, err) } @@ -204,6 +261,13 @@ func (s *PaymentService) RollbackRefund(ctx context.Context, p *RefundPlan, gErr return false } } + if p.DeductionType == payment.DeductionTypeSubscription && p.SubDaysToDeduct > 0 && p.SubscriptionID > 0 { + if _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, p.SubDaysToDeduct); err != nil { + slog.Error("[CRITICAL] subscription rollback failed", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct, "error", err) + s.writeAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED", "admin", map[string]any{"gatewayError": psErrMsg(gErr), "rollbackError": psErrMsg(err), "subDaysDeducted": p.SubDaysToDeduct}) + return false + } + } return true } diff --git a/backend/migrations/103_add_allow_user_refund.sql b/backend/migrations/103_add_allow_user_refund.sql new file mode 100644 index 00000000..79525382 --- /dev/null +++ b/backend/migrations/103_add_allow_user_refund.sql @@ -0,0 +1 @@ +ALTER TABLE payment_provider_instances ADD COLUMN IF NOT EXISTS allow_user_refund BOOLEAN NOT NULL DEFAULT false; diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts index 1389b60f..5cedb107 100644 --- a/frontend/src/api/payment.ts +++ b/frontend/src/api/payment.ts @@ -75,5 +75,10 @@ export const paymentAPI = { /** Request a refund for a completed order */ requestRefund(id: number, data: { reason: string }) { return apiClient.post(`/payment/orders/${id}/refund-request`, data) + }, + + /** Get provider instance IDs that allow user refund */ + getRefundEligibleProviders() { + return apiClient.get<{ provider_instance_ids: string[] }>('/payment/orders/refund-eligible-providers') } } diff --git a/frontend/src/components/payment/PaymentProviderDialog.vue b/frontend/src/components/payment/PaymentProviderDialog.vue index 9b60cba1..10c1bfea 100644 --- a/frontend/src/components/payment/PaymentProviderDialog.vue +++ b/frontend/src/components/payment/PaymentProviderDialog.vue @@ -32,7 +32,8 @@
- + +
{{ t('admin.settings.payment.paymentMode') }}
@@ -243,6 +244,7 @@ const emit = defineEmits<{ enabled: boolean payment_mode: string refund_enabled: boolean + allow_user_refund: boolean config: Record limits: string }] @@ -258,6 +260,7 @@ const form = reactive({ enabled: true, payment_mode: PAYMENT_MODE_QRCODE, refund_enabled: false, + allow_user_refund: false, }) const config = reactive>({}) const limits = reactive>>({}) @@ -433,6 +436,7 @@ function handleSave() { enabled: form.enabled, payment_mode: form.provider_key === 'easypay' ? form.payment_mode : '', refund_enabled: form.refund_enabled, + allow_user_refund: form.refund_enabled ? form.allow_user_refund : false, config: filteredConfig, limits: serializeLimits(), }) @@ -452,6 +456,7 @@ function reset(defaultKey: string) { form.enabled = true form.payment_mode = defaultKey === 'easypay' ? PAYMENT_MODE_QRCODE : '' form.refund_enabled = false + form.allow_user_refund = false clearConfig() applyDefaults() } @@ -463,6 +468,7 @@ function loadProvider(provider: ProviderInstance) { form.enabled = provider.enabled form.payment_mode = provider.payment_mode || (provider.provider_key === 'easypay' ? PAYMENT_MODE_QRCODE : '') form.refund_enabled = provider.refund_enabled + form.allow_user_refund = provider.allow_user_refund clearConfig() // Pre-fill config from API response (non-sensitive in cleartext, sensitive masked as ••••••••) if (provider.config) { diff --git a/frontend/src/components/payment/PaymentProviderList.vue b/frontend/src/components/payment/PaymentProviderList.vue index e942b8c4..49ebc726 100644 --- a/frontend/src/components/payment/PaymentProviderList.vue +++ b/frontend/src/components/payment/PaymentProviderList.vue @@ -115,7 +115,7 @@ const emit = defineEmits<{ create: [] edit: [provider: ProviderInstance] delete: [provider: ProviderInstance] - toggleField: [provider: ProviderInstance, field: 'enabled' | 'refund_enabled'] + toggleField: [provider: ProviderInstance, field: 'enabled' | 'refund_enabled' | 'allow_user_refund'] toggleType: [provider: ProviderInstance, type: string] reorder: [providers: { id: number; sort_order: number }[]] }>() diff --git a/frontend/src/components/payment/ProviderCard.vue b/frontend/src/components/payment/ProviderCard.vue index 9fc3b0ff..aecc8c8a 100644 --- a/frontend/src/components/payment/ProviderCard.vue +++ b/frontend/src/components/payment/ProviderCard.vue @@ -46,6 +46,7 @@
+
@@ -102,7 +105,7 @@ const { t } = useI18n() const appStore = useAppStore() const saving = ref(false) -const planForm = reactive({ name: '', group_id: null as number | null, description: '', price: 0, original_price: 0, validity_days: 30, validity_unit: 'days', for_sale: true }) +const planForm = reactive({ name: '', group_id: null as number | null, description: '', price: 0, original_price: 0, validity_days: 30, validity_unit: 'days', sort_order: 0, for_sale: true }) const planFeaturesText = ref('') const validityUnitOptions = computed(() => [ @@ -130,10 +133,10 @@ const selectedGroupInfo = computed(() => { watch(() => props.show, (visible) => { if (!visible) return if (props.plan) { - Object.assign(planForm, { name: props.plan.name, group_id: props.plan.group_id, description: props.plan.description, price: props.plan.price, original_price: props.plan.original_price || 0, validity_days: props.plan.validity_days, validity_unit: props.plan.validity_unit || 'days', for_sale: props.plan.for_sale }) + Object.assign(planForm, { name: props.plan.name, group_id: props.plan.group_id, description: props.plan.description, price: props.plan.price, original_price: props.plan.original_price || 0, validity_days: props.plan.validity_days, validity_unit: props.plan.validity_unit || 'days', sort_order: props.plan.sort_order || 0, for_sale: props.plan.for_sale }) planFeaturesText.value = (props.plan.features || []).join('\n') } else { - Object.assign(planForm, { name: '', group_id: null, description: '', price: 0, original_price: 0, validity_days: 30, validity_unit: 'days', for_sale: true }) + Object.assign(planForm, { name: '', group_id: null, description: '', price: 0, original_price: 0, validity_days: 30, validity_unit: 'days', sort_order: 0, for_sale: true }) planFeaturesText.value = '' } }) @@ -149,6 +152,7 @@ function buildPlanPayload() { original_price: planForm.original_price || 0, validity_days: planForm.validity_days, validity_unit: planForm.validity_unit, + sort_order: planForm.sort_order, for_sale: planForm.for_sale, features, } diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue index 3c7df572..bc16918c 100644 --- a/frontend/src/views/user/PaymentResultView.vue +++ b/frontend/src/views/user/PaymentResultView.vue @@ -102,10 +102,12 @@ interface ReturnInfo { } const returnInfo = ref(null) +const SUCCESS_STATUSES = new Set(['COMPLETED', 'PAID', 'RECHARGING']) + const isSuccess = computed(() => { // Always prioritize actual order status from backend if (order.value) { - return order.value.status === 'COMPLETED' || order.value.status === 'PAID' + return SUCCESS_STATUSES.has(order.value.status) } // Fallback only when order not loaded if (route.query.status === 'success') return true diff --git a/frontend/src/views/user/UserOrdersView.vue b/frontend/src/views/user/UserOrdersView.vue index 51aacf7d..ea888eb7 100644 --- a/frontend/src/views/user/UserOrdersView.vue +++ b/frontend/src/views/user/UserOrdersView.vue @@ -22,7 +22,7 @@ {{ t('payment.orders.cancel') }} - @@ -102,6 +102,7 @@ const appStore = useAppStore() const loading = ref(false) const actionLoading = ref(false) const orders = ref([]) +const refundEligibleProviders = ref>(new Set()) const currentFilter = ref('') const cancelTargetId = ref(null) const refundTarget = ref(null) @@ -171,5 +172,18 @@ async function confirmRefund() { } } -onMounted(() => fetchOrders()) +function canRequestRefund(order: PaymentOrder): boolean { + if (order.status !== 'COMPLETED') return false + if (!order.provider_instance_id) return false + return refundEligibleProviders.value.has(order.provider_instance_id) +} + +async function loadRefundEligibility() { + try { + const res = await paymentAPI.getRefundEligibleProviders() + refundEligibleProviders.value = new Set(res.data.provider_instance_ids || []) + } catch { /* ignore — default to hiding refund button */ } +} + +onMounted(() => { fetchOrders(); loadRefundEligibility() }) From 58677dd53fc8a0aa56ead0feeb87ac05b8f58018 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 18:34:57 +0800 Subject: [PATCH 81/88] fix: merge 5 PR-related improvements - gateway_handler: pass ParsedRequest to RecordUsage + set in gin.Context - channel_handler: add FeaturesConfig to CRUD (WebSearch channel toggle) - channel_repo: features_config JSONB persistence (Create/Get/Update/List) - security_headers: add Stripe CSP domains (script-src + frame-src) --- .../internal/handler/admin/channel_handler.go | 6 ++ backend/internal/handler/gateway_handler.go | 3 + backend/internal/repository/channel_repo.go | 61 ++++++++++++++----- .../server/middleware/security_headers.go | 13 +++- 4 files changed, 67 insertions(+), 16 deletions(-) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 88d27c47..9151d018 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -35,6 +35,7 @@ type createChannelRequest struct { BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` RestrictModels bool `json:"restrict_models"` Features string `json:"features"` + FeaturesConfig map[string]any `json:"features_config"` ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"` AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"` } @@ -49,6 +50,7 @@ type updateChannelRequest struct { BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` RestrictModels *bool `json:"restrict_models"` Features *string `json:"features"` + FeaturesConfig map[string]any `json:"features_config"` ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"` AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"` } @@ -93,6 +95,7 @@ type channelResponse struct { BillingModelSource string `json:"billing_model_source"` RestrictModels bool `json:"restrict_models"` Features string `json:"features"` + FeaturesConfig map[string]any `json:"features_config"` GroupIDs []int64 `json:"group_ids"` ModelPricing []channelModelPricingResponse `json:"model_pricing"` ModelMapping map[string]map[string]string `json:"model_mapping"` @@ -148,6 +151,7 @@ func channelToResponse(ch *service.Channel) *channelResponse { Status: ch.Status, RestrictModels: ch.RestrictModels, Features: ch.Features, + FeaturesConfig: ch.FeaturesConfig, GroupIDs: ch.GroupIDs, ModelMapping: ch.ModelMapping, CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), @@ -379,6 +383,7 @@ func (h *ChannelHandler) Create(c *gin.Context) { BillingModelSource: req.BillingModelSource, RestrictModels: req.RestrictModels, Features: req.Features, + FeaturesConfig: req.FeaturesConfig, ApplyPricingToAccountStats: req.ApplyPricingToAccountStats, AccountStatsPricingRules: statsRules, }) @@ -414,6 +419,7 @@ func (h *ChannelHandler) Update(c *gin.Context) { BillingModelSource: req.BillingModelSource, RestrictModels: req.RestrictModels, Features: req.Features, + FeaturesConfig: req.FeaturesConfig, ApplyPricingToAccountStats: req.ApplyPricingToAccountStats, } if req.ModelPricing != nil { diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 8ec54420..30065463 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -473,6 +473,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + ParsedRequest: parsedReq, APIKey: apiKey, User: apiKey.User, Account: account, @@ -675,6 +676,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 转发请求 - 根据账号平台分流 + c.Set("parsed_request", parsedReq) var result *service.ForwardResult requestCtx := c.Request.Context() if fs.SwitchCount > 0 { @@ -813,6 +815,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + ParsedRequest: parsedReq, APIKey: currentAPIKey, User: currentAPIKey.User, Account: account, diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go index 583ce895..2cb90aab 100644 --- a/backend/internal/repository/channel_repo.go +++ b/backend/internal/repository/channel_repo.go @@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel if err != nil { return err } + featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig) + if err != nil { + return err + } err = tx.QueryRowContext(ctx, - `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, updated_at`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt) if err != nil { if isUniqueViolation(err) { @@ -80,11 +84,11 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) { ch := &service.Channel{} - var modelMappingJSON []byte + var modelMappingJSON, featuresConfigJSON []byte err := r.db.QueryRowContext(ctx, - `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels WHERE id = $1`, id, - ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt) + ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt) if err == sql.ErrNoRows { return nil, service.ErrChannelNotFound } @@ -92,6 +96,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha return nil, fmt.Errorf("get channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) groupIDs, err := r.GetGroupIDs(ctx, id) if err != nil { @@ -120,10 +125,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel if err != nil { return err } + featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig) + if err != nil { + return err + } result, err := tx.ExecContext(ctx, - `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, apply_pricing_to_account_stats = $8, updated_at = NOW() - WHERE id = $9`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, channel.ID, + `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, apply_pricing_to_account_stats = $9, updated_at = NOW() + WHERE id = $10`, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, channel.ID, ) if err != nil { if isUniqueViolation(err) { @@ -207,7 +216,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati // 查询 channel 列表 dataQuery := fmt.Sprintf( - `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.apply_pricing_to_account_stats, c.created_at, c.updated_at + `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.apply_pricing_to_account_stats, c.created_at, c.updated_at FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`, whereClause, channelListOrderBy(params), argIdx, argIdx+1, ) @@ -223,11 +232,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati var channelIDs []int64 for rows.Next() { var ch service.Channel - var modelMappingJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + var modelMappingJSON, featuresConfigJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) channels = append(channels, ch) channelIDs = append(channelIDs, ch.ID) } @@ -298,7 +308,7 @@ func channelListOrderBy(params pagination.PaginationParams) string { func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { rows, err := r.db.QueryContext(ctx, - `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`, + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`, ) if err != nil { return nil, fmt.Errorf("query all channels: %w", err) @@ -309,11 +319,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err var channelIDs []int64 for rows.Next() { var ch service.Channel - var modelMappingJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + var modelMappingJSON, featuresConfigJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) channels = append(channels, ch) channelIDs = append(channelIDs, ch.ID) } @@ -488,6 +499,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string { return m } +func marshalFeaturesConfig(m map[string]any) ([]byte, error) { + if len(m) == 0 { + return []byte("{}"), nil + } + data, err := json.Marshal(m) + if err != nil { + return nil, fmt.Errorf("marshal features_config: %w", err) + } + return data, nil +} + +func unmarshalFeaturesConfig(data []byte) map[string]any { + if len(data) == 0 { + return nil + } + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + return nil + } + return m +} + // GetGroupPlatforms 批量查询分组 ID 对应的平台 func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) { if len(groupIDs) == 0 { diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index 73210bfc..7021ab2e 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -18,6 +18,8 @@ const ( NonceTemplate = "__CSP_NONCE__" // CloudflareInsightsDomain is the domain for Cloudflare Web Analytics CloudflareInsightsDomain = "https://static.cloudflareinsights.com" + // StripeDomain is the domain for Stripe.js SDK + StripeDomain = "https://*.stripe.com" ) // GenerateNonce generates a cryptographically secure random nonce. @@ -97,8 +99,9 @@ func isAPIRoutePath(c *gin.Context) bool { strings.HasPrefix(path, "/responses") } -// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain. -// This allows the application to work correctly even if the config file has an older CSP policy. +// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights, +// and Stripe.js domains. This allows the application to work correctly even if the +// config file has an older CSP policy. func enhanceCSPPolicy(policy string) string { // Add nonce placeholder to script-src if not present if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") { @@ -110,6 +113,12 @@ func enhanceCSPPolicy(policy string) string { policy = addToDirective(policy, "script-src", CloudflareInsightsDomain) } + // Add Stripe.js domain to script-src and frame-src if not present + if !strings.Contains(policy, "stripe.com") { + policy = addToDirective(policy, "script-src", StripeDomain) + policy = addToDirective(policy, "frame-src", StripeDomain) + } + return policy } From c14d739360de12741920c01ca4bbd140559a90d1 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 18:41:09 +0800 Subject: [PATCH 82/88] fix: resolve 3 code review issues in allow_user_refund 1. PrepareRefund: block refund on provider instance lookup failure instead of silently skipping permission check (medium severity) 2. UpdateProviderInstance: allow enabling refund_enabled and allow_user_refund in the same request by checking req.RefundEnabled value before falling back to DB read 3. ExecuteRefund: only revoke subscription on ErrAdjustWouldExpire, abort on other errors (DB failure, not found) instead of unconditionally revoking --- .../service/payment_config_providers.go | 14 ++++++++--- backend/internal/service/payment_refund.go | 25 +++++++++++++------ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 0f7cb99a..3c406b45 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -231,10 +231,18 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } } if req.AllowUserRefund != nil { - // Only allow enabling when refund_enabled is true + // Only allow enabling when refund_enabled is (or will be) true if *req.AllowUserRefund { - inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) - if err == nil && inst.RefundEnabled { + refundEnabled := false + if req.RefundEnabled != nil { + refundEnabled = *req.RefundEnabled + } else { + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) + if err == nil { + refundEnabled = inst.RefundEnabled + } + } + if refundEnabled { u.SetAllowUserRefund(true) } } else { diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index 75d75b2f..99468433 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "log/slog" "math" @@ -93,15 +94,16 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float // Check provider instance allows admin refund inst, instErr := s.getOrderProviderInstance(ctx, o) if instErr != nil { - slog.Warn("refund: provider instance not found", "orderID", oid, "error", instErr) + slog.Warn("refund: provider instance lookup failed", "orderID", oid, "error", instErr) + return nil, nil, infraerrors.InternalServer("PROVIDER_LOOKUP_FAILED", "failed to look up payment provider for this order") } - if inst != nil && !inst.RefundEnabled { - return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not enabled for this provider") - } - if inst == nil && instErr == nil { + if inst == nil { // Legacy order without provider_instance_id — block refund return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not available for this order") } + if !inst.RefundEnabled { + return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not enabled for this provider") + } if math.IsNaN(amt) || math.IsInf(amt, 0) { return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount") } @@ -183,10 +185,17 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") { _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct) if err != nil { - slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct) - if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil { + if errors.Is(err, ErrAdjustWouldExpire) { + // Deduction would expire the subscription — revoke it entirely + slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct) + if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil { + s.restoreStatus(ctx, p) + return nil, fmt.Errorf("revoke subscription: %w", revokeErr) + } + } else { + // Other errors (DB failure, not found) — abort refund s.restoreStatus(ctx, p) - return nil, fmt.Errorf("revoke subscription: %w", revokeErr) + return nil, fmt.Errorf("deduct subscription days: %w", err) } } } else { From 63f539b3828eccbe2bf416458447acb89c6df8bd Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 14 Apr 2026 19:29:37 +0800 Subject: [PATCH 83/88] fix: merge general improvements from release branch Backend: - gateway_handler: pass subject.UserID instead of int64(0) for user-level routing - setting_handler: add missing BalanceLowNotifyRechargeURL to UpdateSettings response - openai_gateway_service: use applyAccountStatsCost for account stats pricing integration - embed_on: add local file override (data/public/) for embedded frontend assets Frontend: - useTableSelection: add batchUpdate method for batch operations - AccountsView: virtual scrolling params, Set-based isSelected, swipe virtualization - ProxiesView: add batchUpdate to selection and swipe-select - BulkEditAccountModal: fix submit handler to prevent event object as argument - SettingsView: move payload construction outside try block - i18n: add general translation keys (saved, deleted, view, validation, allowUserRefund) - api/client: reorder error fields for consistency - stores/payment: clarify pollOrderStatus JSDoc --- .../internal/handler/admin/setting_handler.go | 1 + backend/internal/handler/gateway_handler.go | 2 +- .../service/openai_gateway_service.go | 11 +--- backend/internal/web/embed_on.go | 65 ++++++++++++++++--- frontend/src/api/client.ts | 2 +- .../account/BulkEditAccountModal.vue | 2 +- frontend/src/composables/useTableSelection.ts | 9 ++- frontend/src/i18n/locales/en.ts | 7 ++ frontend/src/i18n/locales/zh.ts | 6 ++ frontend/src/stores/payment.ts | 2 +- frontend/src/views/admin/AccountsView.vue | 26 ++++++-- frontend/src/views/admin/ProxiesView.vue | 6 +- frontend/src/views/admin/SettingsView.vue | 11 ++-- 13 files changed, 114 insertions(+), 36 deletions(-) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 9b49150c..29c97b4b 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1071,6 +1071,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableCCHSigning: updatedSettings.EnableCCHSigning, BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold, + BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL, AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled, AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails), PaymentEnabled: updatedPaymentCfg.Enabled, diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 30065463..f5eff8c9 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -522,7 +522,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0)) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID) if err != nil { if len(fs.FailedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 9a6fbb8f..6087b7b6 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4575,14 +4575,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // 计算账号统计定价费用(使用最终上游模型匹配自定义规则) if apiKey.GroupID != nil { - statsModel := result.UpstreamModel - if statsModel == "" { - statsModel = result.Model - } - usageLog.AccountStatsCost = resolveAccountStatsCost( - ctx, s.channelService, s.billingService, - account.ID, *apiKey.GroupID, statsModel, - tokens, 1, cost.TotalCost, + applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService, + account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model, + tokens, cost.TotalCost, ) } diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index ad5ac7d8..89d09eef 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -10,6 +10,8 @@ import ( "io" "io/fs" "net/http" + "os" + "path/filepath" "strings" "time" @@ -32,11 +34,12 @@ type PublicSettingsProvider interface { // FrontendServer serves the embedded frontend with settings injection type FrontendServer struct { - distFS fs.FS - fileServer http.Handler - baseHTML []byte - cache *HTMLCache - settings PublicSettingsProvider + distFS fs.FS + fileServer http.Handler + baseHTML []byte + cache *HTMLCache + settings PublicSettingsProvider + overrideDir string // local file override directory } // NewFrontendServer creates a new frontend server with settings injection @@ -62,11 +65,12 @@ func NewFrontendServer(settingsProvider PublicSettingsProvider) (*FrontendServer cache.SetBaseHTML(baseHTML) return &FrontendServer{ - distFS: distFS, - fileServer: http.FileServer(http.FS(distFS)), - baseHTML: baseHTML, - cache: cache, - settings: settingsProvider, + distFS: distFS, + fileServer: http.FileServer(http.FS(distFS)), + baseHTML: baseHTML, + cache: cache, + settings: settingsProvider, + overrideDir: filepath.Join("data", "public"), }, nil } @@ -99,6 +103,11 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc { return } + // Try local override first + if s.tryServeOverride(c, cleanPath) { + return + } + // Serve static files normally s.fileServer.ServeHTTP(c.Writer, c.Request) c.Abort() @@ -114,6 +123,22 @@ func (s *FrontendServer) fileExists(path string) bool { return true } +// tryServeOverride checks if a local override file exists and serves it. +// Files in overrideDir take precedence over embedded files. +func (s *FrontendServer) tryServeOverride(c *gin.Context, cleanPath string) bool { + if s.overrideDir == "" { + return false + } + filePath := filepath.Join(s.overrideDir, filepath.Clean("/"+cleanPath)) + info, err := os.Stat(filePath) + if err != nil || info.IsDir() { + return false + } + c.File(filePath) + c.Abort() + return true +} + func (s *FrontendServer) serveIndexHTML(c *gin.Context) { // Get nonce from context (generated by SecurityHeaders middleware) nonce := middleware.GetNonceFromContext(c) @@ -226,6 +251,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { panic("failed to get dist subdirectory: " + err.Error()) } fileServer := http.FileServer(http.FS(distFS)) + overrideDir := filepath.Join("data", "public") return func(c *gin.Context) { path := c.Request.URL.Path @@ -242,6 +268,10 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { if file, err := distFS.Open(cleanPath); err == nil { _ = file.Close() + // Try local override first + if tryServeOverrideFile(c, overrideDir, cleanPath) { + return + } fileServer.ServeHTTP(c.Writer, c.Request) c.Abort() return @@ -251,6 +281,21 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { } } +// tryServeOverrideFile is a standalone version of tryServeOverride for legacy usage. +func tryServeOverrideFile(c *gin.Context, overrideDir, cleanPath string) bool { + if overrideDir == "" { + return false + } + filePath := filepath.Join(overrideDir, filepath.Clean("/"+cleanPath)) + info, err := os.Stat(filePath) + if err != nil || info.IsDir() { + return false + } + c.File(filePath) + c.Abort() + return true +} + func shouldBypassEmbeddedFrontend(path string) bool { trimmed := strings.TrimSpace(path) return strings.HasPrefix(trimmed, "/api/") || diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 2908c6b1..8a586902 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -270,9 +270,9 @@ apiClient.interceptors.response.use( return Promise.reject({ status, code: apiData.code, + reason: apiData.reason, error: apiData.error, message: apiData.message || apiData.detail || error.message, - reason: apiData.reason, metadata: apiData.metadata, }) } diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 2934fbd9..5461015b 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -5,7 +5,7 @@ width="wide" @close="handleClose" > -
+

diff --git a/frontend/src/composables/useTableSelection.ts b/frontend/src/composables/useTableSelection.ts index a65144a9..f0e096ff 100644 --- a/frontend/src/composables/useTableSelection.ts +++ b/frontend/src/composables/useTableSelection.ts @@ -76,6 +76,12 @@ export function useTableSelection({ rows, getId }: UseTableSelectionOptions) => void) => { + const draft = new Set(selectedSet.value) + updater(draft) + replaceSelectedSet(draft) + } + const selectVisible = () => { toggleVisible(true) } @@ -93,6 +99,7 @@ export function useTableSelection({ rows, getId }: UseTableSelectionOptions { return response.data } - /** Poll order status by ID */ + /** Poll order status by ID (read-only, no upstream check) */ async function pollOrderStatus(orderId: number): Promise { try { const response = await paymentAPI.getOrder(orderId) diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index d7fae112..4fec956b 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -144,6 +144,7 @@