mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 07:22:13 +08:00
Compare commits
13 Commits
v0.1.33
...
revert-114
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
30326cf267 | ||
|
|
8d252303fc | ||
|
|
312cc00d21 | ||
|
|
8e55ee0e2c | ||
|
|
2270a54ff6 | ||
|
|
bb7ade265d | ||
|
|
c5b792add5 | ||
|
|
2ccdc2b8ef | ||
|
|
c1e25b7ecf | ||
|
|
35b768b719 | ||
|
|
0b6371174e | ||
|
|
15e676e9cd | ||
|
|
2c35f0276f |
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -31,7 +30,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
wire.Build(
|
||||
// Infrastructure layer ProviderSets
|
||||
config.ProviderSet,
|
||||
infrastructure.ProviderSet,
|
||||
|
||||
// Business layer ProviderSets
|
||||
repository.ProviderSet,
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -35,18 +34,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client, err := infrastructure.ProvideEnt(configConfig)
|
||||
client, err := repository.ProvideEnt(configConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db, err := infrastructure.ProvideSQLDB(client)
|
||||
db, err := repository.ProvideSQLDB(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userRepository := repository.NewUserRepository(client, db)
|
||||
settingRepository := repository.NewSettingRepository(client)
|
||||
settingService := service.NewSettingService(settingRepository, configConfig)
|
||||
redisClient := infrastructure.ProvideRedis(configConfig)
|
||||
redisClient := repository.ProvideRedis(configConfig)
|
||||
emailCache := repository.NewEmailCache(redisClient)
|
||||
emailService := service.NewEmailService(settingRepository, emailCache)
|
||||
turnstileVerifier := repository.NewTurnstileVerifier()
|
||||
@@ -109,7 +108,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
|
||||
proxyHandler := admin.NewProxyHandler(adminService)
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService)
|
||||
updateCache := repository.NewUpdateCache(redisClient)
|
||||
gitHubReleaseClient := repository.NewGitHubReleaseClient()
|
||||
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
|
||||
|
||||
@@ -10,15 +10,17 @@ import (
|
||||
|
||||
// SettingHandler 系统设置处理器
|
||||
type SettingHandler struct {
|
||||
settingService *service.SettingService
|
||||
emailService *service.EmailService
|
||||
settingService *service.SettingService
|
||||
emailService *service.EmailService
|
||||
turnstileService *service.TurnstileService
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建系统设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService) *SettingHandler {
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,6 +110,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
// Turnstile 参数验证
|
||||
if req.TurnstileEnabled {
|
||||
// 检查必填字段
|
||||
if req.TurnstileSiteKey == "" {
|
||||
response.BadRequest(c, "Turnstile Site Key is required when enabled")
|
||||
return
|
||||
}
|
||||
if req.TurnstileSecretKey == "" {
|
||||
response.BadRequest(c, "Turnstile Secret Key is required when enabled")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前设置,检查参数是否有变化
|
||||
currentSettings, err := h.settingService.GetAllSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 当 site_key 或 secret_key 任一变化时验证(避免配置错误导致无法登录)
|
||||
siteKeyChanged := currentSettings.TurnstileSiteKey != req.TurnstileSiteKey
|
||||
secretKeyChanged := currentSettings.TurnstileSecretKey != req.TurnstileSecretKey
|
||||
if siteKeyChanged || secretKeyChanged {
|
||||
if err := h.turnstileService.ValidateSecretKey(c.Request.Context(), req.TurnstileSecretKey); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
// ProviderSet 是基础设施层的 Wire 依赖提供者集合。
|
||||
//
|
||||
// Wire 是 Google 开发的编译时依赖注入工具。ProviderSet 将相关的依赖提供函数
|
||||
// 组织在一起,便于在应用程序启动时自动组装依赖关系。
|
||||
//
|
||||
// 包含的提供者:
|
||||
// - ProvideEnt: 提供 Ent ORM 客户端
|
||||
// - ProvideSQLDB: 提供底层 SQL 数据库连接
|
||||
// - ProvideRedis: 提供 Redis 客户端
|
||||
var ProviderSet = wire.NewSet(
|
||||
ProvideEnt,
|
||||
ProvideSQLDB,
|
||||
ProvideRedis,
|
||||
)
|
||||
|
||||
// ProvideEnt 为依赖注入提供 Ent 客户端。
|
||||
//
|
||||
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
|
||||
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
|
||||
//
|
||||
// 依赖:config.Config
|
||||
// 提供:*ent.Client
|
||||
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
|
||||
client, _, err := InitEnt(cfg)
|
||||
return client, err
|
||||
}
|
||||
|
||||
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
|
||||
//
|
||||
// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询),
|
||||
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
|
||||
//
|
||||
// 设计说明:
|
||||
// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问
|
||||
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
|
||||
//
|
||||
// 依赖:*ent.Client
|
||||
// 提供:*sql.DB
|
||||
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("nil ent client")
|
||||
}
|
||||
// 从 Ent 客户端获取底层驱动
|
||||
drv, ok := client.Driver().(*entsql.Driver)
|
||||
if !ok {
|
||||
return nil, errors.New("ent driver does not expose *sql.DB")
|
||||
}
|
||||
// 返回驱动持有的 sql.DB 实例
|
||||
return drv.DB(), nil
|
||||
}
|
||||
|
||||
// ProvideRedis 为依赖注入提供 Redis 客户端。
|
||||
//
|
||||
// Redis 用于:
|
||||
// - 分布式锁(如并发控制)
|
||||
// - 缓存(如用户会话、API 响应缓存)
|
||||
// - 速率限制
|
||||
// - 实时统计数据
|
||||
//
|
||||
// 依赖:config.Config
|
||||
// 提供:*redis.Client
|
||||
func ProvideRedis(cfg *config.Config) *redis.Client {
|
||||
return InitRedis(cfg)
|
||||
}
|
||||
@@ -37,8 +37,19 @@ type ClaudeMetadata struct {
|
||||
}
|
||||
|
||||
// ClaudeTool Claude 工具定义
|
||||
// 支持两种格式:
|
||||
// 1. 标准格式: { "name": "...", "description": "...", "input_schema": {...} }
|
||||
// 2. Custom 格式 (MCP): { "type": "custom", "name": "...", "custom": { "description": "...", "input_schema": {...} } }
|
||||
type ClaudeTool struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type,omitempty"` // "custom" 或空(标准格式)
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"` // 标准格式使用
|
||||
InputSchema map[string]any `json:"input_schema,omitempty"` // 标准格式使用
|
||||
Custom *CustomToolSpec `json:"custom,omitempty"` // custom 格式使用
|
||||
}
|
||||
|
||||
// CustomToolSpec MCP custom 工具规格
|
||||
type CustomToolSpec struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]any `json:"input_schema"`
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package antigravity
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -205,6 +206,13 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
// 保留原有 signature(Claude 模型需要有效的 signature)
|
||||
if block.Signature != "" {
|
||||
part.ThoughtSignature = block.Signature
|
||||
} else if !allowDummyThought {
|
||||
// Claude 模型需要有效 signature,跳过无 signature 的 thinking block
|
||||
log.Printf("Warning: skipping thinking block without signature for Claude model")
|
||||
continue
|
||||
} else {
|
||||
// Gemini 模型使用 dummy signature
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
}
|
||||
parts = append(parts, part)
|
||||
|
||||
@@ -379,12 +387,40 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
// 普通工具
|
||||
var funcDecls []GeminiFunctionDecl
|
||||
for _, tool := range tools {
|
||||
// 跳过无效工具名称
|
||||
if tool.Name == "" {
|
||||
log.Printf("Warning: skipping tool with empty name")
|
||||
continue
|
||||
}
|
||||
|
||||
var description string
|
||||
var inputSchema map[string]any
|
||||
|
||||
// 检查是否为 custom 类型工具 (MCP)
|
||||
if tool.Type == "custom" && tool.Custom != nil {
|
||||
// Custom 格式: 从 custom 字段获取 description 和 input_schema
|
||||
description = tool.Custom.Description
|
||||
inputSchema = tool.Custom.InputSchema
|
||||
} else {
|
||||
// 标准格式: 从顶层字段获取
|
||||
description = tool.Description
|
||||
inputSchema = tool.InputSchema
|
||||
}
|
||||
|
||||
// 清理 JSON Schema
|
||||
params := cleanJSONSchema(tool.InputSchema)
|
||||
params := cleanJSONSchema(inputSchema)
|
||||
|
||||
// 为 nil schema 提供默认值
|
||||
if params == nil {
|
||||
params = map[string]any{
|
||||
"type": "OBJECT",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
funcDecls = append(funcDecls, GeminiFunctionDecl{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Description: description,
|
||||
Parameters: params,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
errors2 "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -82,7 +82,7 @@ func TestErrorFrom(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "application_error",
|
||||
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
|
||||
err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusForbidden,
|
||||
wantBody: Response{
|
||||
@@ -94,7 +94,7 @@ func TestErrorFrom(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "bad_request_error",
|
||||
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"),
|
||||
err: errors2.BadRequest("INVALID_REQUEST", "invalid request"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusBadRequest,
|
||||
wantBody: Response{
|
||||
@@ -105,7 +105,7 @@ func TestErrorFrom(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "unauthorized_error",
|
||||
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"),
|
||||
err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusUnauthorized,
|
||||
wantBody: Response{
|
||||
@@ -116,7 +116,7 @@ func TestErrorFrom(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "not_found_error",
|
||||
err: infraerrors.NotFound("NOT_FOUND", "not found"),
|
||||
err: errors2.NotFound("NOT_FOUND", "not found"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusNotFound,
|
||||
wantBody: Response{
|
||||
@@ -127,7 +127,7 @@ func TestErrorFrom(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "conflict_error",
|
||||
err: infraerrors.Conflict("CONFLICT", "conflict"),
|
||||
err: errors2.Conflict("CONFLICT", "conflict"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusConflict,
|
||||
wantBody: Response{
|
||||
@@ -143,7 +143,7 @@ func TestErrorFrom(t *testing.T) {
|
||||
wantHTTPCode: http.StatusInternalServerError,
|
||||
wantBody: Response{
|
||||
Code: http.StatusInternalServerError,
|
||||
Message: infraerrors.UnknownMessage,
|
||||
Message: errors2.UnknownMessage,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package infrastructure
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
@@ -1,4 +1,4 @@
|
||||
package infrastructure
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
@@ -1,6 +1,6 @@
|
||||
// Package infrastructure 提供应用程序的基础设施层组件。
|
||||
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
|
||||
package infrastructure
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -97,7 +96,7 @@ func TestMain(m *testing.M) {
|
||||
log.Printf("failed to open sql db: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := infrastructure.ApplyMigrations(ctx, integrationDB); err != nil {
|
||||
if err := ApplyMigrations(ctx, integrationDB); err != nil {
|
||||
log.Printf("failed to apply db migrations: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package infrastructure
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
|
||||
// Re-apply migrations to verify idempotency (no errors, no duplicate rows).
|
||||
require.NoError(t, infrastructure.ApplyMigrations(context.Background(), integrationDB))
|
||||
require.NoError(t, ApplyMigrations(context.Background(), integrationDB))
|
||||
|
||||
// schema_migrations should have at least the current migration set.
|
||||
var applied int
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package infrastructure
|
||||
package repository
|
||||
|
||||
import (
|
||||
"time"
|
||||
@@ -1,4 +1,4 @@
|
||||
package infrastructure
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -291,13 +291,11 @@ func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
|
||||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
|
||||
// IncrementUsage 原子性地累加用量并校验限额。
|
||||
// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。
|
||||
// 当更新失败时,会执行额外查询确定具体超出的限额类型。
|
||||
// IncrementUsage 原子性地累加订阅用量。
|
||||
// 限额检查已在请求前由 BillingCacheService.CheckBillingEligibility 完成,
|
||||
// 此处仅负责记录实际消费,确保消费数据的完整性。
|
||||
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
// 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加
|
||||
// NULL 限额表示无限制
|
||||
const atomicUpdateSQL = `
|
||||
const updateSQL = `
|
||||
UPDATE user_subscriptions us
|
||||
SET
|
||||
daily_usage_usd = us.daily_usage_usd + $1,
|
||||
@@ -309,13 +307,10 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
|
||||
AND us.deleted_at IS NULL
|
||||
AND us.group_id = g.id
|
||||
AND g.deleted_at IS NULL
|
||||
AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd)
|
||||
AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd)
|
||||
AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd)
|
||||
`
|
||||
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id)
|
||||
result, err := client.ExecContext(ctx, updateSQL, costUSD, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -326,64 +321,11 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
|
||||
}
|
||||
|
||||
if affected > 0 {
|
||||
return nil // 更新成功
|
||||
return nil
|
||||
}
|
||||
|
||||
// affected == 0:可能是订阅不存在、分组已删除、或限额超出
|
||||
// 执行额外查询确定具体原因
|
||||
return r.checkIncrementFailureReason(ctx, id, costUSD)
|
||||
}
|
||||
|
||||
// checkIncrementFailureReason 查询更新失败的具体原因
|
||||
func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error {
|
||||
const checkSQL = `
|
||||
SELECT
|
||||
CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted'
|
||||
WHEN g.id IS NULL THEN 'subscription_not_found'
|
||||
WHEN g.deleted_at IS NOT NULL THEN 'group_deleted'
|
||||
WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded'
|
||||
WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded'
|
||||
WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded'
|
||||
ELSE 'unknown'
|
||||
END AS reason
|
||||
FROM user_subscriptions us
|
||||
LEFT JOIN groups g ON us.group_id = g.id
|
||||
WHERE us.id = $2
|
||||
`
|
||||
|
||||
client := clientFromContext(ctx, r.client)
|
||||
rows, err := client.QueryContext(ctx, checkSQL, costUSD, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
return service.ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
var reason string
|
||||
if err := rows.Scan(&reason); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch reason {
|
||||
case "subscription_not_found", "subscription_deleted", "group_deleted":
|
||||
return service.ErrSubscriptionNotFound
|
||||
case "daily_exceeded":
|
||||
return service.ErrDailyLimitExceeded
|
||||
case "weekly_exceeded":
|
||||
return service.ErrWeeklyLimitExceeded
|
||||
case "monthly_exceeded":
|
||||
return service.ErrMonthlyLimitExceeded
|
||||
default:
|
||||
// unknown 情况理论上不应发生,但作为兜底返回
|
||||
return service.ErrSubscriptionNotFound
|
||||
}
|
||||
// affected == 0:订阅不存在或已删除
|
||||
return service.ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
|
||||
@@ -633,112 +633,7 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
|
||||
}
|
||||
|
||||
// --- 限额检查与软删除过滤测试 ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) mustCreateGroupWithLimits(name string, daily, weekly, monthly *float64) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
create := s.client.Group.Create().
|
||||
SetName(name).
|
||||
SetStatus(service.StatusActive).
|
||||
SetSubscriptionType(service.SubscriptionTypeSubscription)
|
||||
|
||||
if daily != nil {
|
||||
create.SetDailyLimitUsd(*daily)
|
||||
}
|
||||
if weekly != nil {
|
||||
create.SetWeeklyLimitUsd(*weekly)
|
||||
}
|
||||
if monthly != nil {
|
||||
create.SetMonthlyLimitUsd(*monthly)
|
||||
}
|
||||
|
||||
g, err := create.Save(s.ctx)
|
||||
s.Require().NoError(err, "create group with limits")
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_DailyLimitExceeded() {
|
||||
user := s.mustCreateUser("dailylimit@test.com", service.RoleUser)
|
||||
dailyLimit := 10.0
|
||||
group := s.mustCreateGroupWithLimits("g-dailylimit", &dailyLimit, nil, nil)
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
// 先增加 9.0,应该成功
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, 9.0)
|
||||
s.Require().NoError(err, "first increment should succeed")
|
||||
|
||||
// 再增加 2.0,会超过 10.0 限额,应该失败
|
||||
err = s.repo.IncrementUsage(s.ctx, sub.ID, 2.0)
|
||||
s.Require().Error(err, "should fail when daily limit exceeded")
|
||||
s.Require().ErrorIs(err, service.ErrDailyLimitExceeded)
|
||||
|
||||
// 验证用量没有变化
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(9.0, got.DailyUsageUSD, 1e-6, "usage should not change after failed increment")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_WeeklyLimitExceeded() {
|
||||
user := s.mustCreateUser("weeklylimit@test.com", service.RoleUser)
|
||||
weeklyLimit := 50.0
|
||||
group := s.mustCreateGroupWithLimits("g-weeklylimit", nil, &weeklyLimit, nil)
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
// 增加 45.0,应该成功
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, 45.0)
|
||||
s.Require().NoError(err, "first increment should succeed")
|
||||
|
||||
// 再增加 10.0,会超过 50.0 限额,应该失败
|
||||
err = s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
|
||||
s.Require().Error(err, "should fail when weekly limit exceeded")
|
||||
s.Require().ErrorIs(err, service.ErrWeeklyLimitExceeded)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_MonthlyLimitExceeded() {
|
||||
user := s.mustCreateUser("monthlylimit@test.com", service.RoleUser)
|
||||
monthlyLimit := 100.0
|
||||
group := s.mustCreateGroupWithLimits("g-monthlylimit", nil, nil, &monthlyLimit)
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
// 增加 90.0,应该成功
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, 90.0)
|
||||
s.Require().NoError(err, "first increment should succeed")
|
||||
|
||||
// 再增加 20.0,会超过 100.0 限额,应该失败
|
||||
err = s.repo.IncrementUsage(s.ctx, sub.ID, 20.0)
|
||||
s.Require().Error(err, "should fail when monthly limit exceeded")
|
||||
s.Require().ErrorIs(err, service.ErrMonthlyLimitExceeded)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NoLimits() {
|
||||
user := s.mustCreateUser("nolimits@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroupWithLimits("g-nolimits", nil, nil, nil) // 无限额
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
// 应该可以增加任意金额
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1000000.0)
|
||||
s.Require().NoError(err, "should succeed without limits")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(1000000.0, got.DailyUsageUSD, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_AtExactLimit() {
|
||||
user := s.mustCreateUser("exactlimit@test.com", service.RoleUser)
|
||||
dailyLimit := 10.0
|
||||
group := s.mustCreateGroupWithLimits("g-exactlimit", &dailyLimit, nil, nil)
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
// 正好达到限额应该成功
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
|
||||
s.Require().NoError(err, "should succeed at exact limit")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(10.0, got.DailyUsageUSD, 1e-6)
|
||||
}
|
||||
// --- 软删除过滤测试 ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() {
|
||||
user := s.mustCreateUser("softdeleted@test.com", service.RoleUser)
|
||||
@@ -779,7 +674,7 @@ func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() {
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
|
||||
user := s.mustCreateUser("concurrent@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroupWithLimits("g-concurrent", nil, nil, nil) // 无限额
|
||||
group := s.mustCreateGroup("g-concurrent")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
const numGoroutines = 10
|
||||
@@ -808,34 +703,6 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
|
||||
s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_ConcurrentWithLimit() {
|
||||
user := s.mustCreateUser("concurrentlimit@test.com", service.RoleUser)
|
||||
dailyLimit := 5.0
|
||||
group := s.mustCreateGroupWithLimits("g-concurrentlimit", &dailyLimit, nil, nil)
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
// 注意:事务内的操作是串行的,所以这里改为顺序执行以验证限额逻辑
|
||||
// 尝试增加 10 次,每次 1.0,但限额只有 5.0
|
||||
const numAttempts = 10
|
||||
const incrementPerAttempt = 1.0
|
||||
|
||||
successCount := 0
|
||||
for i := 0; i < numAttempts; i++ {
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerAttempt)
|
||||
if err == nil {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
// 验证:应该有 5 次成功(不超过限额),5 次失败(超出限额)
|
||||
s.Require().Equal(5, successCount, "exactly 5 increments should succeed (limit=5, increment=1)")
|
||||
|
||||
// 验证最终用量等于限额
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(dailyLimit, got.DailyUsageUSD, 1e-6, "daily usage should equal limit")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() {
|
||||
baseClient := testEntClient(s.T())
|
||||
tx, err := baseClient.Tx(context.Background())
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/google/wire"
|
||||
@@ -47,4 +52,58 @@ var ProviderSet = wire.NewSet(
|
||||
NewOpenAIOAuthClient,
|
||||
NewGeminiOAuthClient,
|
||||
NewGeminiCliCodeAssistClient,
|
||||
|
||||
ProvideEnt,
|
||||
ProvideSQLDB,
|
||||
ProvideRedis,
|
||||
)
|
||||
|
||||
// ProvideEnt 为依赖注入提供 Ent 客户端。
|
||||
//
|
||||
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
|
||||
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
|
||||
//
|
||||
// 依赖:config.Config
|
||||
// 提供:*ent.Client
|
||||
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
|
||||
client, _, err := InitEnt(cfg)
|
||||
return client, err
|
||||
}
|
||||
|
||||
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
|
||||
//
|
||||
// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询),
|
||||
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
|
||||
//
|
||||
// 设计说明:
|
||||
// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问
|
||||
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
|
||||
//
|
||||
// 依赖:*ent.Client
|
||||
// 提供:*sql.DB
|
||||
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("nil ent client")
|
||||
}
|
||||
// 从 Ent 客户端获取底层驱动
|
||||
drv, ok := client.Driver().(*entsql.Driver)
|
||||
if !ok {
|
||||
return nil, errors.New("ent driver does not expose *sql.DB")
|
||||
}
|
||||
// 返回驱动持有的 sql.DB 实例
|
||||
return drv.DB(), nil
|
||||
}
|
||||
|
||||
// ProvideRedis 为依赖注入提供 Redis 客户端。
|
||||
//
|
||||
// Redis 用于:
|
||||
// - 分布式锁(如并发控制)
|
||||
// - 缓存(如用户会话、API 响应缓存)
|
||||
// - 速率限制
|
||||
// - 实时统计数据
|
||||
//
|
||||
// 依赖:config.Config
|
||||
// 提供:*redis.Client
|
||||
func ProvideRedis(cfg *config.Config) *redis.Client {
|
||||
return InitRedis(cfg)
|
||||
}
|
||||
|
||||
@@ -385,7 +385,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
|
||||
|
||||
jwtAuth := func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
|
||||
@@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
subscriptionType = SubscriptionTypeStandard
|
||||
}
|
||||
|
||||
// 限额字段:0 和 nil 都表示"无限制"
|
||||
dailyLimit := normalizeLimit(input.DailyLimitUSD)
|
||||
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
|
||||
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
|
||||
|
||||
group := &Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
@@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
IsExclusive: input.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: subscriptionType,
|
||||
DailyLimitUSD: input.DailyLimitUSD,
|
||||
WeeklyLimitUSD: input.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: input.MonthlyLimitUSD,
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// normalizeLimit 将 0 或负数转换为 nil(表示无限制)
|
||||
func normalizeLimit(limit *float64) *float64 {
|
||||
if limit == nil || *limit <= 0 {
|
||||
return nil
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.SubscriptionType != "" {
|
||||
group.SubscriptionType = input.SubscriptionType
|
||||
}
|
||||
// 限额字段支持设置为nil(清除限额)或具体值
|
||||
// 限额字段:0 和 nil 都表示"无限制",正数表示具体限额
|
||||
if input.DailyLimitUSD != nil {
|
||||
group.DailyLimitUSD = input.DailyLimitUSD
|
||||
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
|
||||
}
|
||||
if input.WeeklyLimitUSD != nil {
|
||||
group.WeeklyLimitUSD = input.WeeklyLimitUSD
|
||||
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
||||
}
|
||||
if input.MonthlyLimitUSD != nil {
|
||||
group.MonthlyLimitUSD = input.MonthlyLimitUSD
|
||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -28,7 +29,7 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查账户是否需要刷新
|
||||
// Antigravity 使用固定的10分钟刷新窗口,忽略全局配置
|
||||
// Antigravity 使用固定的15分钟刷新窗口,忽略全局配置
|
||||
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
|
||||
if !r.CanRefresh(account) {
|
||||
return false
|
||||
@@ -37,7 +38,13 @@ func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Durati
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Until(*expiresAt) < antigravityRefreshWindow
|
||||
timeUntilExpiry := time.Until(*expiresAt)
|
||||
needsRefresh := timeUntilExpiry < antigravityRefreshWindow
|
||||
if needsRefresh {
|
||||
fmt.Printf("[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v\n",
|
||||
account.ID, expiresAt.Format("2006-01-02 15:04:05"), timeUntilExpiry, antigravityRefreshWindow)
|
||||
}
|
||||
return needsRefresh
|
||||
}
|
||||
|
||||
// Refresh 执行 token 刷新
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -2245,12 +2245,40 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, _ := tm["name"].(string)
|
||||
desc, _ := tm["description"].(string)
|
||||
params := tm["input_schema"]
|
||||
|
||||
var name, desc string
|
||||
var params any
|
||||
|
||||
// 检查是否为 custom 类型工具 (MCP)
|
||||
toolType, _ := tm["type"].(string)
|
||||
if toolType == "custom" {
|
||||
// Custom 格式: 从 custom 字段获取 description 和 input_schema
|
||||
custom, ok := tm["custom"].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, _ = tm["name"].(string)
|
||||
desc, _ = custom["description"].(string)
|
||||
params = custom["input_schema"]
|
||||
} else {
|
||||
// 标准格式: 从顶层字段获取
|
||||
name, _ = tm["name"].(string)
|
||||
desc, _ = tm["description"].(string)
|
||||
params = tm["input_schema"]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 为 nil params 提供默认值
|
||||
if params == nil {
|
||||
params = map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
funcDecls = append(funcDecls, map[string]any{
|
||||
"name": name,
|
||||
"description": desc,
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
@@ -490,6 +490,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use
|
||||
}
|
||||
|
||||
// CheckUsageLimits 检查使用限额(返回错误如果超限)
|
||||
// 用于中间件的快速预检查,additionalCost 通常为 0
|
||||
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error {
|
||||
if !sub.CheckDailyLimit(group, additionalCost) {
|
||||
return ErrDailyLimitExceeded
|
||||
|
||||
@@ -5,12 +5,13 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTurnstileVerificationFailed = infraerrors.BadRequest("TURNSTILE_VERIFICATION_FAILED", "turnstile verification failed")
|
||||
ErrTurnstileNotConfigured = infraerrors.ServiceUnavailable("TURNSTILE_NOT_CONFIGURED", "turnstile not configured")
|
||||
ErrTurnstileInvalidSecretKey = infraerrors.BadRequest("TURNSTILE_INVALID_SECRET_KEY", "invalid turnstile secret key")
|
||||
)
|
||||
|
||||
// TurnstileVerifier 验证 Turnstile token 的接口
|
||||
@@ -83,3 +84,22 @@ func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remote
|
||||
func (s *TurnstileService) IsEnabled(ctx context.Context) bool {
|
||||
return s.settingService.IsTurnstileEnabled(ctx)
|
||||
}
|
||||
|
||||
// ValidateSecretKey 验证 Turnstile Secret Key 是否有效
|
||||
func (s *TurnstileService) ValidateSecretKey(ctx context.Context, secretKey string) error {
|
||||
// 发送一个测试token的验证请求来检查secret_key是否有效
|
||||
result, err := s.verifier.VerifyToken(ctx, secretKey, "test-validation", "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate secret key: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否有 invalid-input-secret 错误
|
||||
for _, code := range result.ErrorCodes {
|
||||
if code == "invalid-input-secret" {
|
||||
return ErrTurnstileInvalidSecretKey
|
||||
}
|
||||
}
|
||||
|
||||
// 其他错误(如 invalid-input-response)说明 secret key 是有效的
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
@@ -262,7 +262,7 @@ func initializeDatabase(cfg *SetupConfig) error {
|
||||
|
||||
migrationCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
return infrastructure.ApplyMigrations(migrationCtx, db)
|
||||
return repository.ApplyMigrations(migrationCtx, db)
|
||||
}
|
||||
|
||||
func createAdminUser(cfg *SetupConfig) error {
|
||||
|
||||
@@ -69,94 +69,108 @@
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Progress bars -->
|
||||
<!-- Progress bars or Unlimited badge -->
|
||||
<div class="space-y-1.5">
|
||||
<div v-if="subscription.group?.daily_limit_usd" class="flex items-center gap-2">
|
||||
<span class="w-8 flex-shrink-0 text-[10px] text-gray-500">{{
|
||||
t('subscriptionProgress.daily')
|
||||
}}</span>
|
||||
<div class="h-1.5 min-w-0 flex-1 rounded-full bg-gray-200 dark:bg-dark-600">
|
||||
<div
|
||||
class="h-1.5 rounded-full transition-all"
|
||||
:class="
|
||||
getProgressBarClass(
|
||||
subscription.daily_usage_usd,
|
||||
subscription.group?.daily_limit_usd
|
||||
)
|
||||
"
|
||||
:style="{
|
||||
width: getProgressWidth(
|
||||
subscription.daily_usage_usd,
|
||||
subscription.group?.daily_limit_usd
|
||||
)
|
||||
}"
|
||||
></div>
|
||||
</div>
|
||||
<span class="w-24 flex-shrink-0 text-right text-[10px] text-gray-500">
|
||||
{{
|
||||
formatUsage(subscription.daily_usage_usd, subscription.group?.daily_limit_usd)
|
||||
}}
|
||||
<!-- Unlimited subscription badge -->
|
||||
<div
|
||||
v-if="isUnlimited(subscription)"
|
||||
class="flex items-center gap-2 rounded-lg bg-gradient-to-r from-emerald-50 to-teal-50 px-2.5 py-1.5 dark:from-emerald-900/20 dark:to-teal-900/20"
|
||||
>
|
||||
<span class="text-lg text-emerald-600 dark:text-emerald-400">∞</span>
|
||||
<span class="text-xs font-medium text-emerald-700 dark:text-emerald-300">
|
||||
{{ t('subscriptionProgress.unlimited') }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div v-if="subscription.group?.weekly_limit_usd" class="flex items-center gap-2">
|
||||
<span class="w-8 flex-shrink-0 text-[10px] text-gray-500">{{
|
||||
t('subscriptionProgress.weekly')
|
||||
}}</span>
|
||||
<div class="h-1.5 min-w-0 flex-1 rounded-full bg-gray-200 dark:bg-dark-600">
|
||||
<div
|
||||
class="h-1.5 rounded-full transition-all"
|
||||
:class="
|
||||
getProgressBarClass(
|
||||
subscription.weekly_usage_usd,
|
||||
subscription.group?.weekly_limit_usd
|
||||
)
|
||||
"
|
||||
:style="{
|
||||
width: getProgressWidth(
|
||||
subscription.weekly_usage_usd,
|
||||
subscription.group?.weekly_limit_usd
|
||||
)
|
||||
}"
|
||||
></div>
|
||||
<!-- Progress bars for limited subscriptions -->
|
||||
<template v-else>
|
||||
<div v-if="subscription.group?.daily_limit_usd" class="flex items-center gap-2">
|
||||
<span class="w-8 flex-shrink-0 text-[10px] text-gray-500">{{
|
||||
t('subscriptionProgress.daily')
|
||||
}}</span>
|
||||
<div class="h-1.5 min-w-0 flex-1 rounded-full bg-gray-200 dark:bg-dark-600">
|
||||
<div
|
||||
class="h-1.5 rounded-full transition-all"
|
||||
:class="
|
||||
getProgressBarClass(
|
||||
subscription.daily_usage_usd,
|
||||
subscription.group?.daily_limit_usd
|
||||
)
|
||||
"
|
||||
:style="{
|
||||
width: getProgressWidth(
|
||||
subscription.daily_usage_usd,
|
||||
subscription.group?.daily_limit_usd
|
||||
)
|
||||
}"
|
||||
></div>
|
||||
</div>
|
||||
<span class="w-24 flex-shrink-0 text-right text-[10px] text-gray-500">
|
||||
{{
|
||||
formatUsage(subscription.daily_usage_usd, subscription.group?.daily_limit_usd)
|
||||
}}
|
||||
</span>
|
||||
</div>
|
||||
<span class="w-24 flex-shrink-0 text-right text-[10px] text-gray-500">
|
||||
{{
|
||||
formatUsage(subscription.weekly_usage_usd, subscription.group?.weekly_limit_usd)
|
||||
}}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div v-if="subscription.group?.monthly_limit_usd" class="flex items-center gap-2">
|
||||
<span class="w-8 flex-shrink-0 text-[10px] text-gray-500">{{
|
||||
t('subscriptionProgress.monthly')
|
||||
}}</span>
|
||||
<div class="h-1.5 min-w-0 flex-1 rounded-full bg-gray-200 dark:bg-dark-600">
|
||||
<div
|
||||
class="h-1.5 rounded-full transition-all"
|
||||
:class="
|
||||
getProgressBarClass(
|
||||
subscription.monthly_usage_usd,
|
||||
subscription.group?.monthly_limit_usd
|
||||
)
|
||||
"
|
||||
:style="{
|
||||
width: getProgressWidth(
|
||||
subscription.monthly_usage_usd,
|
||||
subscription.group?.monthly_limit_usd
|
||||
)
|
||||
}"
|
||||
></div>
|
||||
<div v-if="subscription.group?.weekly_limit_usd" class="flex items-center gap-2">
|
||||
<span class="w-8 flex-shrink-0 text-[10px] text-gray-500">{{
|
||||
t('subscriptionProgress.weekly')
|
||||
}}</span>
|
||||
<div class="h-1.5 min-w-0 flex-1 rounded-full bg-gray-200 dark:bg-dark-600">
|
||||
<div
|
||||
class="h-1.5 rounded-full transition-all"
|
||||
:class="
|
||||
getProgressBarClass(
|
||||
subscription.weekly_usage_usd,
|
||||
subscription.group?.weekly_limit_usd
|
||||
)
|
||||
"
|
||||
:style="{
|
||||
width: getProgressWidth(
|
||||
subscription.weekly_usage_usd,
|
||||
subscription.group?.weekly_limit_usd
|
||||
)
|
||||
}"
|
||||
></div>
|
||||
</div>
|
||||
<span class="w-24 flex-shrink-0 text-right text-[10px] text-gray-500">
|
||||
{{
|
||||
formatUsage(subscription.weekly_usage_usd, subscription.group?.weekly_limit_usd)
|
||||
}}
|
||||
</span>
|
||||
</div>
|
||||
<span class="w-24 flex-shrink-0 text-right text-[10px] text-gray-500">
|
||||
{{
|
||||
formatUsage(
|
||||
subscription.monthly_usage_usd,
|
||||
subscription.group?.monthly_limit_usd
|
||||
)
|
||||
}}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div v-if="subscription.group?.monthly_limit_usd" class="flex items-center gap-2">
|
||||
<span class="w-8 flex-shrink-0 text-[10px] text-gray-500">{{
|
||||
t('subscriptionProgress.monthly')
|
||||
}}</span>
|
||||
<div class="h-1.5 min-w-0 flex-1 rounded-full bg-gray-200 dark:bg-dark-600">
|
||||
<div
|
||||
class="h-1.5 rounded-full transition-all"
|
||||
:class="
|
||||
getProgressBarClass(
|
||||
subscription.monthly_usage_usd,
|
||||
subscription.group?.monthly_limit_usd
|
||||
)
|
||||
"
|
||||
:style="{
|
||||
width: getProgressWidth(
|
||||
subscription.monthly_usage_usd,
|
||||
subscription.group?.monthly_limit_usd
|
||||
)
|
||||
}"
|
||||
></div>
|
||||
</div>
|
||||
<span class="w-24 flex-shrink-0 text-right text-[10px] text-gray-500">
|
||||
{{
|
||||
formatUsage(
|
||||
subscription.monthly_usage_usd,
|
||||
subscription.group?.monthly_limit_usd
|
||||
)
|
||||
}}
|
||||
</span>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -215,7 +229,19 @@ function getMaxUsagePercentage(sub: UserSubscription): number {
|
||||
return percentages.length > 0 ? Math.max(...percentages) : 0
|
||||
}
|
||||
|
||||
function isUnlimited(sub: UserSubscription): boolean {
|
||||
return (
|
||||
!sub.group?.daily_limit_usd &&
|
||||
!sub.group?.weekly_limit_usd &&
|
||||
!sub.group?.monthly_limit_usd
|
||||
)
|
||||
}
|
||||
|
||||
function getProgressDotClass(sub: UserSubscription): string {
|
||||
// Unlimited subscriptions get a special color
|
||||
if (isUnlimited(sub)) {
|
||||
return 'bg-emerald-500'
|
||||
}
|
||||
const maxPercentage = getMaxUsagePercentage(sub)
|
||||
if (maxPercentage >= 90) return 'bg-red-500'
|
||||
if (maxPercentage >= 70) return 'bg-orange-500'
|
||||
|
||||
@@ -749,6 +749,7 @@ export default {
|
||||
weekly: 'Weekly',
|
||||
monthly: 'Monthly',
|
||||
noLimits: 'No limits configured',
|
||||
unlimited: 'Unlimited',
|
||||
resetNow: 'Resetting soon',
|
||||
windowNotActive: 'Window not active',
|
||||
resetInMinutes: 'Resets in {minutes}m',
|
||||
@@ -1492,7 +1493,8 @@ export default {
|
||||
expiresToday: 'Expires today',
|
||||
expiresTomorrow: 'Expires tomorrow',
|
||||
viewAll: 'View all subscriptions',
|
||||
noSubscriptions: 'No active subscriptions'
|
||||
noSubscriptions: 'No active subscriptions',
|
||||
unlimited: 'Unlimited'
|
||||
},
|
||||
|
||||
// Version Badge
|
||||
@@ -1535,6 +1537,7 @@ export default {
|
||||
expires: 'Expires',
|
||||
noExpiration: 'No expiration',
|
||||
unlimited: 'Unlimited',
|
||||
unlimitedDesc: 'No usage limits on this subscription',
|
||||
daily: 'Daily',
|
||||
weekly: 'Weekly',
|
||||
monthly: 'Monthly',
|
||||
|
||||
@@ -840,6 +840,7 @@ export default {
|
||||
weekly: '每周',
|
||||
monthly: '每月',
|
||||
noLimits: '未配置限额',
|
||||
unlimited: '无限制',
|
||||
resetNow: '即将重置',
|
||||
windowNotActive: '窗口未激活',
|
||||
resetInMinutes: '{minutes} 分钟后重置',
|
||||
@@ -1689,7 +1690,8 @@ export default {
|
||||
expiresToday: '今天到期',
|
||||
expiresTomorrow: '明天到期',
|
||||
viewAll: '查看全部订阅',
|
||||
noSubscriptions: '暂无有效订阅'
|
||||
noSubscriptions: '暂无有效订阅',
|
||||
unlimited: '无限制'
|
||||
},
|
||||
|
||||
// Version Badge
|
||||
@@ -1731,6 +1733,7 @@ export default {
|
||||
expires: '到期时间',
|
||||
noExpiration: '无到期时间',
|
||||
unlimited: '无限制',
|
||||
unlimitedDesc: '该订阅无用量限制',
|
||||
daily: '每日',
|
||||
weekly: '每周',
|
||||
monthly: '每月',
|
||||
|
||||
@@ -202,16 +202,19 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- No Limits -->
|
||||
<!-- No Limits - Unlimited badge -->
|
||||
<div
|
||||
v-if="
|
||||
!row.group?.daily_limit_usd &&
|
||||
!row.group?.weekly_limit_usd &&
|
||||
!row.group?.monthly_limit_usd
|
||||
"
|
||||
class="text-xs text-gray-500"
|
||||
class="flex items-center gap-2 rounded-lg bg-gradient-to-r from-emerald-50 to-teal-50 px-3 py-2 dark:from-emerald-900/20 dark:to-teal-900/20"
|
||||
>
|
||||
{{ t('admin.subscriptions.noLimits') }}
|
||||
<span class="text-lg text-emerald-600 dark:text-emerald-400">∞</span>
|
||||
<span class="text-xs font-medium text-emerald-700 dark:text-emerald-300">
|
||||
{{ t('admin.subscriptions.unlimited') }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -230,18 +230,26 @@
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- No limits configured -->
|
||||
<!-- No limits configured - Unlimited badge -->
|
||||
<div
|
||||
v-if="
|
||||
!subscription.group?.daily_limit_usd &&
|
||||
!subscription.group?.weekly_limit_usd &&
|
||||
!subscription.group?.monthly_limit_usd
|
||||
"
|
||||
class="py-4 text-center"
|
||||
class="flex items-center justify-center rounded-xl bg-gradient-to-r from-emerald-50 to-teal-50 py-6 dark:from-emerald-900/20 dark:to-teal-900/20"
|
||||
>
|
||||
<span class="text-sm text-gray-500 dark:text-dark-400">{{
|
||||
t('userSubscriptions.unlimited')
|
||||
}}</span>
|
||||
<div class="flex items-center gap-3">
|
||||
<span class="text-4xl text-emerald-600 dark:text-emerald-400">∞</span>
|
||||
<div>
|
||||
<p class="text-sm font-medium text-emerald-700 dark:text-emerald-300">
|
||||
{{ t('userSubscriptions.unlimited') }}
|
||||
</p>
|
||||
<p class="text-xs text-emerald-600/70 dark:text-emerald-400/70">
|
||||
{{ t('userSubscriptions.unlimitedDesc') }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user