mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 15:32:13 +08:00
Compare commits
92 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fbdff4f34f | ||
|
|
0aa480283f | ||
|
|
cd9d31f5f2 | ||
|
|
cbfce49aa1 | ||
|
|
d3e73f1260 | ||
|
|
b5ca6a654c | ||
|
|
94749b12ac | ||
|
|
523fa9f71e | ||
|
|
54636781ea | ||
|
|
5187db5ee5 | ||
|
|
0b9c4ae69e | ||
|
|
0d5a8a95c8 | ||
|
|
9cd97c9e1d | ||
|
|
d521191e87 | ||
|
|
fd78993b91 | ||
|
|
80cce858cb | ||
|
|
0743652d92 | ||
|
|
96bec5c9b1 | ||
|
|
cfeb6b8b14 | ||
|
|
481310dea0 | ||
|
|
ea2821d11d | ||
|
|
7a0de1765f | ||
|
|
35b1bc3753 | ||
|
|
8d38788672 | ||
|
|
c615a4264d | ||
|
|
227d506c53 | ||
|
|
36a86e9ab4 | ||
|
|
f133b051dc | ||
|
|
7af1bdbf4c | ||
|
|
016d7ef645 | ||
|
|
f1e47291cd | ||
|
|
d7e9ae38e4 | ||
|
|
88be981afc | ||
|
|
3f92a43170 | ||
|
|
2101f1d1c8 | ||
|
|
f0f920e49f | ||
|
|
95583fce83 | ||
|
|
254f12543c | ||
|
|
cf8a64528c | ||
|
|
2b79c4e8b7 | ||
|
|
429f38d0c9 | ||
|
|
2714be99a9 | ||
|
|
d851818035 | ||
|
|
576bf4639c | ||
|
|
9db52838b5 | ||
|
|
bfcd9501c2 | ||
|
|
12252c6005 | ||
|
|
2d89f36687 | ||
|
|
3d608c2625 | ||
|
|
739d0ee61e | ||
|
|
22f07a7bb6 | ||
|
|
16eec4eb41 | ||
|
|
ecb2c5353c | ||
|
|
06d5876b02 | ||
|
|
e5a77853b0 | ||
|
|
9780f0fd9d | ||
|
|
3559830882 | ||
|
|
5594680130 | ||
|
|
50855ec15f | ||
|
|
f9f33e7b5c | ||
|
|
1bec35999b | ||
|
|
632318ad33 | ||
|
|
456e8984b0 | ||
|
|
eea949853a | ||
|
|
85fd1e4a2c | ||
|
|
6682d06c99 | ||
|
|
efa470efc7 | ||
|
|
79d1585250 | ||
|
|
2d1a15b196 | ||
|
|
09431cfc0b | ||
|
|
46cb82bac0 | ||
|
|
b2d71da2a2 | ||
|
|
2d6e1d26c0 | ||
|
|
50734c5edc | ||
|
|
040dc27ea5 | ||
|
|
d7090de0e0 | ||
|
|
cab681c7d1 | ||
|
|
01f990a5c9 | ||
|
|
5763f5ced3 | ||
|
|
f79b0f0fad | ||
|
|
34183b527b | ||
|
|
bceed08fc3 | ||
|
|
5deef27e1d | ||
|
|
1ac8b1f03e | ||
|
|
0b30cc2b7e | ||
|
|
03a8ae62e5 | ||
|
|
e36fb98fb9 | ||
|
|
55258bf099 | ||
|
|
dc109827b7 | ||
|
|
71c28e436a | ||
|
|
2bafc28a9b | ||
|
|
aea48ae1ab |
58
.github/workflows/release.yml
vendored
58
.github/workflows/release.yml
vendored
@@ -143,3 +143,61 @@ jobs:
|
||||
repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api
|
||||
short-description: "Sub2API - AI API Gateway Platform"
|
||||
readme-filepath: ./deploy/DOCKER.md
|
||||
|
||||
# Send Telegram notification
|
||||
- name: Send Telegram Notification
|
||||
env:
|
||||
TELEGRAM_BOT_TOKEN: ${{ secrets.TELEGRAM_BOT_TOKEN }}
|
||||
TELEGRAM_CHAT_ID: ${{ secrets.TELEGRAM_CHAT_ID }}
|
||||
continue-on-error: true
|
||||
run: |
|
||||
# 检查必要的环境变量
|
||||
if [ -z "$TELEGRAM_BOT_TOKEN" ] || [ -z "$TELEGRAM_CHAT_ID" ]; then
|
||||
echo "Telegram credentials not configured, skipping notification"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
TAG_NAME=${GITHUB_REF#refs/tags/}
|
||||
VERSION=${TAG_NAME#v}
|
||||
REPO="${{ github.repository }}"
|
||||
DOCKER_IMAGE="${{ secrets.DOCKERHUB_USERNAME }}/sub2api"
|
||||
|
||||
# 获取 tag message 内容
|
||||
TAG_MESSAGE='${{ steps.tag_message.outputs.message }}'
|
||||
|
||||
# 限制消息长度(Telegram 消息限制 4096 字符,预留空间给头尾固定内容)
|
||||
if [ ${#TAG_MESSAGE} -gt 3500 ]; then
|
||||
TAG_MESSAGE="${TAG_MESSAGE:0:3500}..."
|
||||
fi
|
||||
|
||||
# 构建消息内容
|
||||
MESSAGE="🚀 *Sub2API 新版本发布!*"$'\n'$'\n'
|
||||
MESSAGE+="📦 版本号: \`${VERSION}\`"$'\n'$'\n'
|
||||
|
||||
# 添加更新内容
|
||||
if [ -n "$TAG_MESSAGE" ]; then
|
||||
MESSAGE+="${TAG_MESSAGE}"$'\n'$'\n'
|
||||
fi
|
||||
|
||||
MESSAGE+="🐳 *Docker 部署:*"$'\n'
|
||||
MESSAGE+="\`\`\`bash"$'\n'
|
||||
MESSAGE+="docker pull ${DOCKER_IMAGE}:${TAG_NAME}"$'\n'
|
||||
MESSAGE+="docker pull ${DOCKER_IMAGE}:latest"$'\n'
|
||||
MESSAGE+="\`\`\`"$'\n'$'\n'
|
||||
MESSAGE+="🔗 *相关链接:*"$'\n'
|
||||
MESSAGE+="• [GitHub Release](https://github.com/${REPO}/releases/tag/${TAG_NAME})"$'\n'
|
||||
MESSAGE+="• [Docker Hub](https://hub.docker.com/r/${DOCKER_IMAGE})"$'\n'$'\n'
|
||||
MESSAGE+="#Sub2API #Release #${TAG_NAME//./_}"
|
||||
|
||||
# 发送消息
|
||||
curl -s -X POST "https://api.telegram.org/bot${TELEGRAM_BOT_TOKEN}/sendMessage" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "$(jq -n \
|
||||
--arg chat_id "${TELEGRAM_CHAT_ID}" \
|
||||
--arg text "${MESSAGE}" \
|
||||
'{
|
||||
chat_id: $chat_id,
|
||||
text: $text,
|
||||
parse_mode: "Markdown",
|
||||
disable_web_page_preview: true
|
||||
}')"
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -21,6 +21,9 @@ coverage.html
|
||||
# 依赖(使用 go mod)
|
||||
vendor/
|
||||
|
||||
# Go 编译缓存
|
||||
backend/.gocache/
|
||||
|
||||
# ===================
|
||||
# Node.js / Vue 前端
|
||||
# ===================
|
||||
@@ -29,6 +32,7 @@ frontend/node_modules/
|
||||
frontend/dist/
|
||||
*.local
|
||||
*.tsbuildinfo
|
||||
vite.config.d.ts
|
||||
|
||||
# 日志
|
||||
npm-debug.log*
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 1: Frontend Builder
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM node:20-alpine AS frontend-builder
|
||||
FROM node:24-alpine AS frontend-builder
|
||||
|
||||
WORKDIR /app/frontend
|
||||
|
||||
@@ -24,7 +24,7 @@ RUN npm run build
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 2: Backend Builder
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM golang:1.24-alpine AS backend-builder
|
||||
FROM golang:1.25-alpine AS backend-builder
|
||||
|
||||
# Build arguments for version info (set by CI)
|
||||
ARG VERSION=docker
|
||||
|
||||
@@ -23,6 +23,8 @@ linters:
|
||||
desc: "service must not import repository"
|
||||
- pkg: gorm.io/gorm
|
||||
desc: "service must not import gorm"
|
||||
- pkg: github.com/redis/go-redis/v9
|
||||
desc: "service must not import redis"
|
||||
handler-no-repository:
|
||||
list-mode: original
|
||||
files:
|
||||
@@ -30,6 +32,10 @@ linters:
|
||||
deny:
|
||||
- pkg: github.com/Wei-Shaw/sub2api/internal/repository
|
||||
desc: "handler must not import repository"
|
||||
- pkg: gorm.io/gorm
|
||||
desc: "handler must not import gorm"
|
||||
- pkg: github.com/redis/go-redis/v9
|
||||
desc: "handler must not import redis"
|
||||
errcheck:
|
||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||
# Such cases aren't reported by default.
|
||||
|
||||
@@ -69,6 +69,7 @@ func provideCleanup(
|
||||
emailQueue *service.EmailQueueService,
|
||||
oauth *service.OAuthService,
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -99,6 +100,10 @@ func provideCleanup(
|
||||
openaiOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"GeminiOAuthService", func() error {
|
||||
geminiOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
return rdb.Close()
|
||||
}},
|
||||
|
||||
@@ -48,8 +48,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
|
||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||
authHandler := handler.NewAuthHandler(authService)
|
||||
userService := service.NewUserService(userRepository)
|
||||
authHandler := handler.NewAuthHandler(authService, userService)
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyRepository := repository.NewApiKeyRepository(db)
|
||||
groupRepository := repository.NewGroupRepository(db)
|
||||
@@ -80,17 +80,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
||||
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
|
||||
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
|
||||
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
|
||||
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
|
||||
geminiTokenCache := repository.NewGeminiTokenCache(client)
|
||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, httpUpstream)
|
||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
proxyHandler := admin.NewProxyHandler(adminService)
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService)
|
||||
@@ -101,7 +107,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
systemHandler := handler.ProvideSystemHandler(updateService)
|
||||
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||
gatewayCache := repository.NewGatewayCache(client)
|
||||
pricingRemoteClient := repository.NewPricingRemoteClient()
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
@@ -111,19 +117,22 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityCache := repository.NewIdentityCache(client)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream)
|
||||
timingWheelService := service.ProvideTimingWheelService()
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, userService, concurrencyService, billingCacheService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService)
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware)
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
|
||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, configConfig)
|
||||
v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
v := provideCleanup(db, client, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -153,6 +162,7 @@ func provideCleanup(
|
||||
emailQueue *service.EmailQueueService,
|
||||
oauth *service.OAuthService,
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -182,6 +192,10 @@ func provideCleanup(
|
||||
openaiOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"GeminiOAuthService", func() error {
|
||||
geminiOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
return rdb.Close()
|
||||
}},
|
||||
|
||||
@@ -11,23 +11,26 @@ require (
|
||||
github.com/google/wire v0.7.0
|
||||
github.com/imroc/req/v3 v3.56.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/redis/go-redis/v9 v9.7.3
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/spf13/viper v1.18.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
|
||||
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/zeromicro/go-zero v1.9.4
|
||||
golang.org/x/crypto v0.44.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/term v0.37.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/datatypes v1.2.0
|
||||
gorm.io/driver/postgres v1.5.4
|
||||
gorm.io/gorm v1.25.5
|
||||
)
|
||||
|
||||
require (
|
||||
dario.cat/mergo v1.0.2 // indirect
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
@@ -47,6 +50,7 @@ require (
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/ebitengine/purego v0.8.4 // indirect
|
||||
github.com/fatih/color v1.18.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
@@ -57,6 +61,7 @@ require (
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/subcommands v1.2.0 // indirect
|
||||
@@ -64,9 +69,9 @@ require (
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/icholy/digest v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.4 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.7.4 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
@@ -75,7 +80,8 @@ require (
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/magiconair/properties v1.8.10 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mdelapenya/tlscert v0.2.0 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
@@ -90,7 +96,7 @@ require (
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
@@ -102,6 +108,7 @@ require (
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||
github.com/spf13/afero v1.11.0 // indirect
|
||||
github.com/spf13/cast v1.6.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
@@ -120,7 +127,8 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/atomic v1.10.0 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||
@@ -132,4 +140,5 @@ require (
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gorm.io/driver/mysql v1.5.2 // indirect
|
||||
)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
|
||||
@@ -50,6 +52,8 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
||||
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
@@ -77,10 +81,17 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
|
||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo=
|
||||
github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
@@ -104,12 +115,12 @@ github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo=
|
||||
github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8=
|
||||
github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
|
||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
|
||||
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
@@ -133,10 +144,17 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI=
|
||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
|
||||
github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE=
|
||||
github.com/microsoft/go-mssqldb v0.17.0/go.mod h1:OkoNGhGEs8EZqchVTtochlXruEhEOaO4S0d2sB5aeGQ=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||
@@ -166,8 +184,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
|
||||
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
@@ -175,12 +193,14 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
||||
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
|
||||
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||
github.com/quic-go/quic-go v0.56.0 h1:q/TW+OLismmXAehgFLczhCDTYB3bFmua4D9lsNBWxvY=
|
||||
github.com/quic-go/quic-go v0.56.0/go.mod h1:9gx5KsFQtw2oZ6GZTyh+7YEvOxWCL9WZAepnHxgAo6c=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
|
||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
@@ -195,6 +215,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
|
||||
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
@@ -215,6 +237,7 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
@@ -246,26 +269,30 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
github.com/zeromicro/go-zero v1.9.4 h1:aRLFoISqAYijABtkbliQC5SsI5TbizJpQvoHc9xup8k=
|
||||
github.com/zeromicro/go-zero v1.9.4/go.mod h1:a17JOTch25SWxBcUgJZYps60hygK3pIYdw7nGwlcS38=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
|
||||
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
|
||||
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.24.0 h1:t6wl9SPayj+c7lEIFgm4ooDBZVb01IhLB4InpomhRw8=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.24.0/go.mod h1:iSDOcsnSA5INXzZtwaBPrKp/lWu/V14Dd+llD0oI2EA=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.24.0 h1:Xw8U6u2f8DK2XAkGRFV7BBLENgnTGX9i4rQRxJf+/vs=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.24.0/go.mod h1:6KW1Fm6R/s6Z3PGXwSJN2K4eT6wQB3vXX6CVnYX9NmM=
|
||||
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
|
||||
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
||||
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
|
||||
go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8=
|
||||
go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
|
||||
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
|
||||
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
|
||||
@@ -288,6 +315,7 @@ golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -319,8 +347,17 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco=
|
||||
gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04=
|
||||
gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs=
|
||||
gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8=
|
||||
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
|
||||
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
|
||||
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
|
||||
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
|
||||
gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0=
|
||||
gorm.io/driver/sqlserver v1.4.1/go.mod h1:DJ4P+MeZbc5rvY58PnmN1Lnyvb5gw5NPzGshHDnJLig=
|
||||
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
||||
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||
|
||||
@@ -18,6 +18,17 @@ type Config struct {
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
}
|
||||
|
||||
type GeminiConfig struct {
|
||||
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
|
||||
}
|
||||
|
||||
type GeminiOAuthConfig struct {
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
ClientSecret string `mapstructure:"client_secret"`
|
||||
Scopes string `mapstructure:"scopes"`
|
||||
}
|
||||
|
||||
// TokenRefreshConfig OAuth token自动刷新配置
|
||||
@@ -211,9 +222,16 @@ func setDefaults() {
|
||||
// TokenRefresh
|
||||
viper.SetDefault("token_refresh.enabled", true)
|
||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
|
||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
|
||||
// Gemini OAuth - configure via environment variables or config file
|
||||
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
|
||||
// Default: uses Gemini CLI public credentials (set via environment)
|
||||
viper.SetDefault("gemini.oauth.client_id", "")
|
||||
viper.SetDefault("gemini.oauth.client_secret", "")
|
||||
viper.SetDefault("gemini.oauth.scopes", "")
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
|
||||
@@ -2,9 +2,11 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
@@ -30,6 +32,7 @@ type AccountHandler struct {
|
||||
adminService service.AdminService
|
||||
oauthService *service.OAuthService
|
||||
openaiOAuthService *service.OpenAIOAuthService
|
||||
geminiOAuthService *service.GeminiOAuthService
|
||||
rateLimitService *service.RateLimitService
|
||||
accountUsageService *service.AccountUsageService
|
||||
accountTestService *service.AccountTestService
|
||||
@@ -42,6 +45,7 @@ func NewAccountHandler(
|
||||
adminService service.AdminService,
|
||||
oauthService *service.OAuthService,
|
||||
openaiOAuthService *service.OpenAIOAuthService,
|
||||
geminiOAuthService *service.GeminiOAuthService,
|
||||
rateLimitService *service.RateLimitService,
|
||||
accountUsageService *service.AccountUsageService,
|
||||
accountTestService *service.AccountTestService,
|
||||
@@ -52,6 +56,7 @@ func NewAccountHandler(
|
||||
adminService: adminService,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
geminiOAuthService: geminiOAuthService,
|
||||
rateLimitService: rateLimitService,
|
||||
accountUsageService: accountUsageService,
|
||||
accountTestService: accountTestService,
|
||||
@@ -102,7 +107,7 @@ type BulkUpdateAccountsRequest struct {
|
||||
|
||||
// AccountWithConcurrency extends Account with real-time concurrency info
|
||||
type AccountWithConcurrency struct {
|
||||
*model.Account
|
||||
*dto.Account
|
||||
CurrentConcurrency int `json:"current_concurrency"`
|
||||
}
|
||||
|
||||
@@ -137,7 +142,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
result := make([]AccountWithConcurrency, len(accounts))
|
||||
for i := range accounts {
|
||||
result[i] = AccountWithConcurrency{
|
||||
Account: &accounts[i],
|
||||
Account: dto.AccountFromService(&accounts[i]),
|
||||
CurrentConcurrency: concurrencyCounts[accounts[i].ID],
|
||||
}
|
||||
}
|
||||
@@ -160,7 +165,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
// Create handles creating a new account
|
||||
@@ -188,7 +193,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
// Update handles updating an account
|
||||
@@ -222,7 +227,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
// Delete handles deleting an account
|
||||
@@ -345,6 +350,19 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else if account.Platform == service.PlatformGemini {
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
newCredentials = h.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
@@ -362,10 +380,14 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
// Update token-related fields
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = tokenInfo.ExpiresIn
|
||||
newCredentials["expires_at"] = tokenInfo.ExpiresAt
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if strings.TrimSpace(tokenInfo.RefreshToken) != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if strings.TrimSpace(tokenInfo.Scope) != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
}
|
||||
|
||||
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
@@ -376,7 +398,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, updatedAccount)
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
// GetStats handles getting account statistics
|
||||
@@ -425,7 +447,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
// BatchCreate handles batch creating accounts
|
||||
@@ -801,7 +823,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
// GetAvailableModels handles getting available models for an account
|
||||
@@ -858,6 +880,44 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Gemini accounts
|
||||
if account.IsGemini() {
|
||||
// For OAuth accounts: return default Gemini models
|
||||
if account.IsOAuth() {
|
||||
response.Success(c, geminicli.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: return models based on model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, geminicli.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
var models []geminicli.Model
|
||||
for requestedModel := range mapping {
|
||||
var found bool
|
||||
for _, dm := range geminicli.DefaultModels {
|
||||
if dm.ID == requestedModel {
|
||||
models = append(models, dm)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
models = append(models, geminicli.Model{
|
||||
ID: requestedModel,
|
||||
Type: "model",
|
||||
DisplayName: requestedModel,
|
||||
CreatedAt: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
response.Success(c, models)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Claude/Anthropic accounts
|
||||
// For OAuth and Setup-Token accounts: return default models
|
||||
if account.IsOAuth() {
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
135
backend/internal/handler/admin/gemini_oauth_handler.go
Normal file
135
backend/internal/handler/admin/gemini_oauth_handler.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GeminiOAuthHandler struct {
|
||||
geminiOAuthService *service.GeminiOAuthService
|
||||
}
|
||||
|
||||
func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *GeminiOAuthHandler {
|
||||
return &GeminiOAuthHandler{geminiOAuthService: geminiOAuthService}
|
||||
}
|
||||
|
||||
// GET /api/v1/admin/gemini/oauth/capabilities
|
||||
func (h *GeminiOAuthHandler) GetCapabilities(c *gin.Context) {
|
||||
cfg := h.geminiOAuthService.GetOAuthConfig()
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
type GeminiGenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
ProjectID string `json:"project_id"`
|
||||
// OAuth 类型: "code_assist" (需要 project_id) 或 "ai_studio" (不需要 project_id)
|
||||
// 默认为 "code_assist" 以保持向后兼容
|
||||
OAuthType string `json:"oauth_type"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates Google OAuth authorization URL for Gemini.
|
||||
// POST /api/v1/admin/gemini/oauth/auth-url
|
||||
func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
var req GeminiGenerateAuthURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 默认使用 code_assist 以保持向后兼容
|
||||
oauthType := strings.TrimSpace(req.OAuthType)
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
if oauthType != "code_assist" && oauthType != "ai_studio" {
|
||||
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
|
||||
return
|
||||
}
|
||||
|
||||
// Always pass the "hosted" callback URI; the OAuth service may override it depending on
|
||||
// oauth_type and whether the built-in Gemini CLI OAuth client is used.
|
||||
redirectURI := deriveGeminiRedirectURI(c)
|
||||
result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, redirectURI, req.ProjectID, oauthType)
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
// Treat missing/invalid OAuth client configuration as a user/config error.
|
||||
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
|
||||
response.BadRequest(c, "Failed to generate auth URL: "+msg)
|
||||
return
|
||||
}
|
||||
response.InternalError(c, "Failed to generate auth URL: "+msg)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
type GeminiExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
// OAuth 类型: "code_assist" 或 "ai_studio",需要与 GenerateAuthURL 时的类型一致
|
||||
OAuthType string `json:"oauth_type"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens.
|
||||
// POST /api/v1/admin/gemini/oauth/exchange-code
|
||||
func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
var req GeminiExchangeCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 默认使用 code_assist 以保持向后兼容
|
||||
oauthType := strings.TrimSpace(req.OAuthType)
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
if oauthType != "code_assist" && oauthType != "ai_studio" {
|
||||
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.geminiOAuthService.ExchangeCode(c.Request.Context(), &service.GeminiExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
State: req.State,
|
||||
Code: req.Code,
|
||||
ProxyID: req.ProxyID,
|
||||
OAuthType: oauthType,
|
||||
})
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to exchange code: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
func deriveGeminiRedirectURI(c *gin.Context) string {
|
||||
origin := strings.TrimSpace(c.GetHeader("Origin"))
|
||||
if origin != "" {
|
||||
return strings.TrimRight(origin, "/") + "/auth/callback"
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if xfProto := strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")); xfProto != "" {
|
||||
scheme = strings.TrimSpace(strings.Split(xfProto, ",")[0])
|
||||
}
|
||||
|
||||
host := strings.TrimSpace(c.Request.Host)
|
||||
if xfHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); xfHost != "" {
|
||||
host = strings.TrimSpace(strings.Split(xfHost, ",")[0])
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s/auth/callback", scheme, host)
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -69,7 +69,11 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, groups, total, page, pageSize)
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
}
|
||||
response.Paginated(c, outGroups, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetAll handles getting all active groups without pagination
|
||||
@@ -77,7 +81,7 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
func (h *GroupHandler) GetAll(c *gin.Context) {
|
||||
platform := c.Query("platform")
|
||||
|
||||
var groups []model.Group
|
||||
var groups []service.Group
|
||||
var err error
|
||||
|
||||
if platform != "" {
|
||||
@@ -91,7 +95,11 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, groups)
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
}
|
||||
response.Success(c, outGroups)
|
||||
}
|
||||
|
||||
// GetByID handles getting a group by ID
|
||||
@@ -109,7 +117,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, group)
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
}
|
||||
|
||||
// Create handles creating a new group
|
||||
@@ -137,7 +145,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, group)
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
}
|
||||
|
||||
// Update handles updating a group
|
||||
@@ -172,7 +180,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, group)
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
}
|
||||
|
||||
// Delete handles deleting a group
|
||||
@@ -229,5 +237,9 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, keys, total, page, pageSize)
|
||||
outKeys := make([]dto.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -163,7 +164,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, updatedAccount)
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||
@@ -224,5 +225,5 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, account)
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -57,7 +58,11 @@ func (h *ProxyHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, proxies, total, page, pageSize)
|
||||
out := make([]dto.Proxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
out = append(out, *dto.ProxyFromService(&proxies[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetAll handles getting all active proxies without pagination
|
||||
@@ -72,7 +77,11 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, proxies)
|
||||
out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -82,7 +91,11 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, proxies)
|
||||
out := make([]dto.Proxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
out = append(out, *dto.ProxyFromService(&proxies[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// GetByID handles getting a proxy by ID
|
||||
@@ -100,7 +113,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, proxy)
|
||||
response.Success(c, dto.ProxyFromService(proxy))
|
||||
}
|
||||
|
||||
// Create handles creating a new proxy
|
||||
@@ -125,7 +138,7 @@ func (h *ProxyHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, proxy)
|
||||
response.Success(c, dto.ProxyFromService(proxy))
|
||||
}
|
||||
|
||||
// Update handles updating a proxy
|
||||
@@ -157,7 +170,7 @@ func (h *ProxyHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, proxy)
|
||||
response.Success(c, dto.ProxyFromService(proxy))
|
||||
}
|
||||
|
||||
// Delete handles deleting a proxy
|
||||
@@ -233,7 +246,11 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, accounts, total, page, pageSize)
|
||||
out := make([]dto.Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
out = append(out, *dto.AccountFromService(&accounts[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -29,8 +30,8 @@ type GenerateRedeemCodesRequest struct {
|
||||
Count int `json:"count" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription"`
|
||||
Value float64 `json:"value" binding:"min=0"`
|
||||
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
||||
ValidityDays int `json:"validity_days"` // 订阅类型使用,默认30天
|
||||
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年
|
||||
}
|
||||
|
||||
// List handles listing all redeem codes with pagination
|
||||
@@ -47,7 +48,11 @@ func (h *RedeemHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, codes, total, page, pageSize)
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a redeem code by ID
|
||||
@@ -65,7 +70,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, code)
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
}
|
||||
|
||||
// Generate handles generating new redeem codes
|
||||
@@ -89,7 +94,11 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, codes)
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// Delete handles deleting a redeem code
|
||||
@@ -148,7 +157,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, code)
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
}
|
||||
|
||||
// GetStats handles getting redeem code statistics
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -31,7 +31,28 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, settings)
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SmtpHost: settings.SmtpHost,
|
||||
SmtpPort: settings.SmtpPort,
|
||||
SmtpUsername: settings.SmtpUsername,
|
||||
SmtpPassword: settings.SmtpPassword,
|
||||
SmtpFrom: settings.SmtpFrom,
|
||||
SmtpFromName: settings.SmtpFromName,
|
||||
SmtpUseTLS: settings.SmtpUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: settings.TurnstileSecretKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
ApiBaseUrl: settings.ApiBaseUrl,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocUrl: settings.DocUrl,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
@@ -87,7 +108,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
settings := &model.SystemSettings{
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SmtpHost: req.SmtpHost,
|
||||
@@ -122,7 +143,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, updatedSettings)
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SmtpHost: updatedSettings.SmtpHost,
|
||||
SmtpPort: updatedSettings.SmtpPort,
|
||||
SmtpUsername: updatedSettings.SmtpUsername,
|
||||
SmtpPassword: updatedSettings.SmtpPassword,
|
||||
SmtpFrom: updatedSettings.SmtpFrom,
|
||||
SmtpFromName: updatedSettings.SmtpFromName,
|
||||
SmtpUseTLS: updatedSettings.SmtpUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
ApiBaseUrl: updatedSettings.ApiBaseUrl,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocUrl: updatedSettings.DocUrl,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
})
|
||||
}
|
||||
|
||||
// TestSmtpRequest 测试SMTP连接请求
|
||||
|
||||
@@ -3,9 +3,10 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -40,7 +41,7 @@ func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *S
|
||||
type AssignSubscriptionRequest struct {
|
||||
UserID int64 `json:"user_id" binding:"required"`
|
||||
GroupID int64 `json:"group_id" binding:"required"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // max 100 years
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
@@ -48,13 +49,13 @@ type AssignSubscriptionRequest struct {
|
||||
type BulkAssignSubscriptionRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required,min=1"`
|
||||
GroupID int64 `json:"group_id" binding:"required"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // max 100 years
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// ExtendSubscriptionRequest represents extend subscription request
|
||||
type ExtendSubscriptionRequest struct {
|
||||
Days int `json:"days" binding:"required,min=1"`
|
||||
Days int `json:"days" binding:"required,min=1,max=36500"` // max 100 years
|
||||
}
|
||||
|
||||
// List handles listing all subscriptions with pagination and filters
|
||||
@@ -82,7 +83,11 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
|
||||
// GetByID handles getting a subscription by ID
|
||||
@@ -100,7 +105,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, subscription)
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
}
|
||||
|
||||
// GetProgress handles getting subscription usage progress
|
||||
@@ -145,7 +150,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, subscription)
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
}
|
||||
|
||||
// BulkAssign handles bulk assigning subscriptions to multiple users
|
||||
@@ -172,7 +177,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
response.Success(c, dto.BulkAssignResultFromService(result))
|
||||
}
|
||||
|
||||
// Extend handles extending a subscription
|
||||
@@ -196,7 +201,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, subscription)
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
@@ -234,7 +239,11 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
|
||||
// ListByUser handles listing subscriptions for a specific user
|
||||
@@ -252,15 +261,18 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, subscriptions)
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// Helper function to get admin ID from context
|
||||
func getAdminIDFromContext(c *gin.Context) int64 {
|
||||
if user, exists := c.Get("user"); exists {
|
||||
if u, ok := user.(*model.User); ok && u != nil {
|
||||
return u.ID
|
||||
}
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return 0
|
||||
return subject.UserID
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
@@ -39,7 +40,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
|
||||
// Parse filters
|
||||
var userID, apiKeyID int64
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
@@ -58,6 +59,47 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
||||
id, err := strconv.ParseInt(accountIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account_id")
|
||||
return
|
||||
}
|
||||
accountID = id
|
||||
}
|
||||
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
id, err := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group_id")
|
||||
return
|
||||
}
|
||||
groupID = id
|
||||
}
|
||||
|
||||
model := c.Query("model")
|
||||
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
return
|
||||
}
|
||||
stream = &val
|
||||
}
|
||||
|
||||
var billingType *int8
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
bt := int8(val)
|
||||
billingType = &bt
|
||||
}
|
||||
|
||||
// Parse date range
|
||||
var startTime, endTime *time.Time
|
||||
if startDateStr := c.Query("start_date"); startDateStr != "" {
|
||||
@@ -82,10 +124,15 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
UserID: userID,
|
||||
ApiKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
}
|
||||
|
||||
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||
@@ -94,7 +141,11 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, records, result.Total, page, pageSize)
|
||||
out := make([]dto.UsageLog, 0, len(records))
|
||||
for i := range records {
|
||||
out = append(out, *dto.UsageLogFromService(&records[i]))
|
||||
}
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// Stats handles getting usage statistics with filters
|
||||
|
||||
@@ -3,6 +3,7 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -68,7 +69,11 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, users, total, page, pageSize)
|
||||
out := make([]dto.User, 0, len(users))
|
||||
for i := range users {
|
||||
out = append(out, *dto.UserFromService(&users[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a user by ID
|
||||
@@ -86,7 +91,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, user)
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
}
|
||||
|
||||
// Create handles creating a new user
|
||||
@@ -113,7 +118,7 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, user)
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
}
|
||||
|
||||
// Update handles updating a user
|
||||
@@ -148,7 +153,7 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, user)
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
}
|
||||
|
||||
// Delete handles deleting a user
|
||||
@@ -190,7 +195,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, user)
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
}
|
||||
|
||||
// GetUserAPIKeys handles getting user's API keys
|
||||
@@ -210,7 +215,11 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, keys, total, page, pageSize)
|
||||
out := make([]dto.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
out = append(out, *dto.ApiKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetUserUsage handles getting user's usage statistics
|
||||
|
||||
@@ -3,9 +3,10 @@ package handler
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -40,42 +41,34 @@ type UpdateAPIKeyRequest struct {
|
||||
// List handles listing user's API keys with pagination
|
||||
// GET /api/v1/api-keys
|
||||
func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, keys, result.Total, page, pageSize)
|
||||
out := make([]dto.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
out = append(out, *dto.ApiKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a single API key
|
||||
// GET /api/v1/api-keys/:id
|
||||
func (h *APIKeyHandler) GetByID(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -92,26 +85,20 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if key.UserID != user.ID {
|
||||
if key.UserID != subject.UserID {
|
||||
response.Forbidden(c, "Not authorized to access this key")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, key)
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
}
|
||||
|
||||
// Create handles creating a new API key
|
||||
// POST /api/v1/api-keys
|
||||
func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -126,27 +113,21 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
GroupID: req.GroupID,
|
||||
CustomKey: req.CustomKey,
|
||||
}
|
||||
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq)
|
||||
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, key)
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
}
|
||||
|
||||
// Update handles updating an API key
|
||||
// PUT /api/v1/api-keys/:id
|
||||
func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -171,27 +152,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
svcReq.Status = &req.Status
|
||||
}
|
||||
|
||||
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq)
|
||||
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, key)
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
}
|
||||
|
||||
// Delete handles deleting an API key
|
||||
// DELETE /api/v1/api-keys/:id
|
||||
func (h *APIKeyHandler) Delete(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -201,7 +176,7 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID)
|
||||
err = h.apiKeyService.Delete(c.Request.Context(), keyID, subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -213,23 +188,21 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
|
||||
// GetAvailableGroups 获取用户可以绑定的分组列表
|
||||
// GET /api/v1/groups/available
|
||||
func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
|
||||
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, groups)
|
||||
out := make([]dto.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
out = append(out, *dto.GroupFromService(&groups[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -11,12 +12,14 @@ import (
|
||||
// AuthHandler handles authentication-related requests
|
||||
type AuthHandler struct {
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
|
||||
func NewAuthHandler(authService *service.AuthService, userService *service.UserService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,9 +52,9 @@ type LoginRequest struct {
|
||||
|
||||
// AuthResponse 认证响应格式(匹配前端期望)
|
||||
type AuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
User *model.User `json:"user"`
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
User *dto.User `json:"user"`
|
||||
}
|
||||
|
||||
// Register handles user registration
|
||||
@@ -80,7 +83,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: user,
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -135,24 +138,24 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: user,
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
}
|
||||
|
||||
// GetCurrentUser handles getting current authenticated user
|
||||
// GET /api/v1/auth/me
|
||||
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
user, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, user)
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
}
|
||||
|
||||
310
backend/internal/handler/dto/mappers.go
Normal file
310
backend/internal/handler/dto/mappers.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package dto
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
func UserFromServiceShallow(u *service.User) *User {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &User{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Wechat: u.Wechat,
|
||||
Notes: u.Notes,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
AllowedGroups: u.AllowedGroups,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func UserFromService(u *service.User) *User {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
out := UserFromServiceShallow(u)
|
||||
if len(u.ApiKeys) > 0 {
|
||||
out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys))
|
||||
for i := range u.ApiKeys {
|
||||
k := u.ApiKeys[i]
|
||||
out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k))
|
||||
}
|
||||
}
|
||||
if len(u.Subscriptions) > 0 {
|
||||
out.Subscriptions = make([]UserSubscription, 0, len(u.Subscriptions))
|
||||
for i := range u.Subscriptions {
|
||||
s := u.Subscriptions[i]
|
||||
out.Subscriptions = append(out.Subscriptions, *UserSubscriptionFromService(&s))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ApiKeyFromService(k *service.ApiKey) *ApiKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &ApiKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return &Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
AccountCount: g.AccountCount,
|
||||
}
|
||||
}
|
||||
|
||||
func GroupFromService(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
out := GroupFromServiceShallow(g)
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
for i := range g.AccountGroups {
|
||||
ag := g.AccountGroups[i]
|
||||
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
return &Account{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Platform: a.Platform,
|
||||
Type: a.Type,
|
||||
Credentials: a.Credentials,
|
||||
Extra: a.Extra,
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
Priority: a.Priority,
|
||||
Status: a.Status,
|
||||
ErrorMessage: a.ErrorMessage,
|
||||
LastUsedAt: a.LastUsedAt,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
Schedulable: a.Schedulable,
|
||||
RateLimitedAt: a.RateLimitedAt,
|
||||
RateLimitResetAt: a.RateLimitResetAt,
|
||||
OverloadUntil: a.OverloadUntil,
|
||||
SessionWindowStart: a.SessionWindowStart,
|
||||
SessionWindowEnd: a.SessionWindowEnd,
|
||||
SessionWindowStatus: a.SessionWindowStatus,
|
||||
GroupIDs: a.GroupIDs,
|
||||
}
|
||||
}
|
||||
|
||||
func AccountFromService(a *service.Account) *Account {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
out := AccountFromServiceShallow(a)
|
||||
out.Proxy = ProxyFromService(a.Proxy)
|
||||
if len(a.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(a.AccountGroups))
|
||||
for i := range a.AccountGroups {
|
||||
ag := a.AccountGroups[i]
|
||||
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
|
||||
}
|
||||
}
|
||||
if len(a.Groups) > 0 {
|
||||
out.Groups = make([]*Group, 0, len(a.Groups))
|
||||
for _, g := range a.Groups {
|
||||
out.Groups = append(out.Groups, GroupFromServiceShallow(g))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup {
|
||||
if ag == nil {
|
||||
return nil
|
||||
}
|
||||
return &AccountGroup{
|
||||
AccountID: ag.AccountID,
|
||||
GroupID: ag.GroupID,
|
||||
Priority: ag.Priority,
|
||||
CreatedAt: ag.CreatedAt,
|
||||
Account: AccountFromServiceShallow(ag.Account),
|
||||
Group: GroupFromServiceShallow(ag.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func ProxyFromService(p *service.Proxy) *Proxy {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &Proxy{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
Protocol: p.Protocol,
|
||||
Host: p.Host,
|
||||
Port: p.Port,
|
||||
Username: p.Username,
|
||||
Password: p.Password,
|
||||
Status: p.Status,
|
||||
CreatedAt: p.CreatedAt,
|
||||
UpdatedAt: p.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWithAccountCount {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &ProxyWithAccountCount{
|
||||
Proxy: *ProxyFromService(&p.Proxy),
|
||||
AccountCount: p.AccountCount,
|
||||
}
|
||||
}
|
||||
|
||||
func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
||||
if rc == nil {
|
||||
return nil
|
||||
}
|
||||
return &RedeemCode{
|
||||
ID: rc.ID,
|
||||
Code: rc.Code,
|
||||
Type: rc.Type,
|
||||
Value: rc.Value,
|
||||
Status: rc.Status,
|
||||
UsedBy: rc.UsedBy,
|
||||
UsedAt: rc.UsedAt,
|
||||
Notes: rc.Notes,
|
||||
CreatedAt: rc.CreatedAt,
|
||||
GroupID: rc.GroupID,
|
||||
ValidityDays: rc.ValidityDays,
|
||||
User: UserFromServiceShallow(rc.User),
|
||||
Group: GroupFromServiceShallow(rc.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
return &UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
ApiKeyID: l.ApiKeyID,
|
||||
AccountID: l.AccountID,
|
||||
RequestID: l.RequestID,
|
||||
Model: l.Model,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
InputTokens: l.InputTokens,
|
||||
OutputTokens: l.OutputTokens,
|
||||
CacheCreationTokens: l.CacheCreationTokens,
|
||||
CacheReadTokens: l.CacheReadTokens,
|
||||
CacheCreation5mTokens: l.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: l.CacheCreation1hTokens,
|
||||
InputCost: l.InputCost,
|
||||
OutputCost: l.OutputCost,
|
||||
CacheCreationCost: l.CacheCreationCost,
|
||||
CacheReadCost: l.CacheReadCost,
|
||||
TotalCost: l.TotalCost,
|
||||
ActualCost: l.ActualCost,
|
||||
RateMultiplier: l.RateMultiplier,
|
||||
BillingType: l.BillingType,
|
||||
Stream: l.Stream,
|
||||
DurationMs: l.DurationMs,
|
||||
FirstTokenMs: l.FirstTokenMs,
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
ApiKey: ApiKeyFromService(l.ApiKey),
|
||||
Account: AccountFromService(l.Account),
|
||||
Group: GroupFromServiceShallow(l.Group),
|
||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||
}
|
||||
}
|
||||
|
||||
func SettingFromService(s *service.Setting) *Setting {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &Setting{
|
||||
ID: s.ID,
|
||||
Key: s.Key,
|
||||
Value: s.Value,
|
||||
UpdatedAt: s.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscription {
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserSubscription{
|
||||
ID: sub.ID,
|
||||
UserID: sub.UserID,
|
||||
GroupID: sub.GroupID,
|
||||
StartsAt: sub.StartsAt,
|
||||
ExpiresAt: sub.ExpiresAt,
|
||||
Status: sub.Status,
|
||||
DailyWindowStart: sub.DailyWindowStart,
|
||||
WeeklyWindowStart: sub.WeeklyWindowStart,
|
||||
MonthlyWindowStart: sub.MonthlyWindowStart,
|
||||
DailyUsageUSD: sub.DailyUsageUSD,
|
||||
WeeklyUsageUSD: sub.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: sub.MonthlyUsageUSD,
|
||||
AssignedBy: sub.AssignedBy,
|
||||
AssignedAt: sub.AssignedAt,
|
||||
Notes: sub.Notes,
|
||||
CreatedAt: sub.CreatedAt,
|
||||
UpdatedAt: sub.UpdatedAt,
|
||||
User: UserFromServiceShallow(sub.User),
|
||||
Group: GroupFromServiceShallow(sub.Group),
|
||||
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
|
||||
}
|
||||
}
|
||||
|
||||
func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
subs := make([]UserSubscription, 0, len(r.Subscriptions))
|
||||
for i := range r.Subscriptions {
|
||||
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
|
||||
}
|
||||
return &BulkAssignResult{
|
||||
SuccessCount: r.SuccessCount,
|
||||
FailedCount: r.FailedCount,
|
||||
Subscriptions: subs,
|
||||
Errors: r.Errors,
|
||||
}
|
||||
}
|
||||
43
backend/internal/handler/dto/settings.go
Normal file
43
backend/internal/handler/dto/settings.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package dto
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password,omitempty"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
219
backend/internal/handler/dto/types.go
Normal file
219
backend/internal/handler/dto/types.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package dto
|
||||
|
||||
import "time"
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Wechat string `json:"wechat"`
|
||||
Notes string `json:"notes"`
|
||||
Role string `json:"role"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Status string `json:"status"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
ApiKeys []ApiKey `json:"api_keys,omitempty"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
type ApiKey struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
Status string `json:"status"`
|
||||
|
||||
SubscriptionType string `json:"subscription_type"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
Schedulable bool `json:"schedulable"`
|
||||
|
||||
RateLimitedAt *time.Time `json:"rate_limited_at"`
|
||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||
OverloadUntil *time.Time `json:"overload_until"`
|
||||
|
||||
SessionWindowStart *time.Time `json:"session_window_start"`
|
||||
SessionWindowEnd *time.Time `json:"session_window_end"`
|
||||
SessionWindowStatus string `json:"session_window_status"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
GroupIDs []int64 `json:"group_ids,omitempty"`
|
||||
Groups []*Group `json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
type AccountGroup struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
Priority int `json:"priority"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
Account *Account `json:"account,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
type Proxy struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Protocol string `json:"protocol"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"-"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type ProxyWithAccountCount struct {
|
||||
Proxy
|
||||
AccountCount int64 `json:"account_count"`
|
||||
}
|
||||
|
||||
type RedeemCode struct {
|
||||
ID int64 `json:"id"`
|
||||
Code string `json:"code"`
|
||||
Type string `json:"type"`
|
||||
Value float64 `json:"value"`
|
||||
Status string `json:"status"`
|
||||
UsedBy *int64 `json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
Notes string `json:"notes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
type UsageLog struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
SubscriptionID *int64 `json:"subscription_id"`
|
||||
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationTokens int `json:"cache_creation_tokens"`
|
||||
CacheReadTokens int `json:"cache_read_tokens"`
|
||||
|
||||
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
|
||||
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
|
||||
|
||||
InputCost float64 `json:"input_cost"`
|
||||
OutputCost float64 `json:"output_cost"`
|
||||
CacheCreationCost float64 `json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `json:"cache_read_cost"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
|
||||
BillingType int8 `json:"billing_type"`
|
||||
Stream bool `json:"stream"`
|
||||
DurationMs *int `json:"duration_ms"`
|
||||
FirstTokenMs *int `json:"first_token_ms"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
ApiKey *ApiKey `json:"api_key,omitempty"`
|
||||
Account *Account `json:"account,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
Subscription *UserSubscription `json:"subscription,omitempty"`
|
||||
}
|
||||
|
||||
type Setting struct {
|
||||
ID int64 `json:"id"`
|
||||
Key string `json:"key"`
|
||||
Value string `json:"value"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type UserSubscription struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
|
||||
StartsAt time.Time `json:"starts_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Status string `json:"status"`
|
||||
|
||||
DailyWindowStart *time.Time `json:"daily_window_start"`
|
||||
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
|
||||
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
|
||||
|
||||
DailyUsageUSD float64 `json:"daily_usage_usd"`
|
||||
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
|
||||
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
|
||||
|
||||
AssignedBy *int64 `json:"assigned_by"`
|
||||
AssignedAt time.Time `json:"assigned_at"`
|
||||
Notes string `json:"notes"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
AssignedByUser *User `json:"assigned_by_user,omitempty"`
|
||||
}
|
||||
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -10,7 +11,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -22,15 +22,23 @@ import (
|
||||
// GatewayHandler handles API gateway requests
|
||||
type GatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
geminiCompatService *service.GeminiMessagesCompatService
|
||||
userService *service.UserService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
}
|
||||
|
||||
// NewGatewayHandler creates a new GatewayHandler
|
||||
func NewGatewayHandler(gatewayService *service.GatewayService, userService *service.UserService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService) *GatewayHandler {
|
||||
func NewGatewayHandler(
|
||||
gatewayService *service.GatewayService,
|
||||
geminiCompatService *service.GeminiMessagesCompatService,
|
||||
userService *service.UserService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
) *GatewayHandler {
|
||||
return &GatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
geminiCompatService: geminiCompatService,
|
||||
userService: userService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
|
||||
@@ -47,7 +55,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
@@ -82,8 +90,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 0. 检查wait队列是否已满
|
||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
@@ -92,10 +100,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// 确保在函数退出时减少wait计数
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
|
||||
// 1. 首先获取用户并发槽位
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
@@ -106,7 +114,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2. 【新增】Wait后二次检查余额/订阅
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
|
||||
return
|
||||
@@ -115,56 +123,163 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 计算粘性会话hash
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
platform := ""
|
||||
if apiKey.Group != nil {
|
||||
platform = apiKey.Group.Platform
|
||||
}
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if req.Stream {
|
||||
sendMockWarmupStream(c, req.Model)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, req.Model)
|
||||
if platform == service.PlatformGemini {
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if req.Stream {
|
||||
sendMockWarmupStream(c, req.Model)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, req.Model)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 转发请求
|
||||
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
defer accountReleaseFunc()
|
||||
}
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
// 转发请求
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
if err != nil {
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: user,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if req.Stream {
|
||||
sendMockWarmupStream(c, req.Model)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, req.Model)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 转发请求
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Models handles listing available models
|
||||
@@ -198,7 +313,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
@@ -223,7 +338,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 余额模式:返回钱包余额
|
||||
latestUser, err := h.userService.GetByID(c.Request.Context(), user.ID)
|
||||
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
|
||||
return
|
||||
@@ -241,7 +356,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
// 逻辑:
|
||||
// 1. 如果日/周/月任一限额达到100%,返回0
|
||||
// 2. 否则返回所有已配置周期中剩余额度的最小值
|
||||
func (h *GatewayHandler) calculateSubscriptionRemaining(group *model.Group, sub *model.UserSubscription) float64 {
|
||||
func (h *GatewayHandler) calculateSubscriptionRemaining(group *service.Group, sub *service.UserSubscription) float64 {
|
||||
var remainingValues []float64
|
||||
|
||||
// 检查日限额
|
||||
@@ -292,6 +407,28 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "overloaded_error", "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
@@ -334,7 +471,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
_, ok = middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
@@ -366,7 +503,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
|
||||
// 校验 billing eligibility(订阅/余额)
|
||||
// 【注意】不计算并发,但需要校验订阅/余额
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -69,11 +68,11 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
|
||||
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
|
||||
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -83,17 +82,17 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
|
||||
return h.waitForSlotWithPing(c, "user", userID, maxConcurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -103,7 +102,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
|
||||
return h.waitForSlotWithPing(c, "account", accountID, maxConcurrency, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
||||
|
||||
320
backend/internal/handler/gemini_v1beta_handler.go
Normal file
320
backend/internal/handler/gemini_v1beta_handler.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GeminiV1BetaListModels proxies:
|
||||
// GET /v1beta/models
|
||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
}
|
||||
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
|
||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models")
|
||||
if err != nil {
|
||||
googleError(c, http.StatusBadGateway, err.Error())
|
||||
return
|
||||
}
|
||||
if shouldFallbackGeminiModels(res) {
|
||||
c.JSON(http.StatusOK, gemini.FallbackModelsList())
|
||||
return
|
||||
}
|
||||
writeUpstreamResponse(c, res)
|
||||
}
|
||||
|
||||
// GeminiV1BetaGetModel proxies:
|
||||
// GET /v1beta/models/{model}
|
||||
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
}
|
||||
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
|
||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||
return
|
||||
}
|
||||
|
||||
modelName := strings.TrimSpace(c.Param("model"))
|
||||
if modelName == "" {
|
||||
googleError(c, http.StatusBadRequest, "Missing model in URL")
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models/"+modelName)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusBadGateway, err.Error())
|
||||
return
|
||||
}
|
||||
if shouldFallbackGeminiModels(res) {
|
||||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||||
return
|
||||
}
|
||||
writeUpstreamResponse(c, res)
|
||||
}
|
||||
|
||||
// GeminiV1BetaModels proxies Gemini native REST endpoints like:
|
||||
// POST /v1beta/models/{model}:generateContent
|
||||
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
|
||||
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetApiKeyFromContext(c)
|
||||
if !ok || apiKey == nil {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
}
|
||||
authSubject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
googleError(c, http.StatusInternalServerError, "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
|
||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||
return
|
||||
}
|
||||
|
||||
modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/"))
|
||||
if err != nil {
|
||||
googleError(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
stream := action == "streamGenerateContent"
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusBadRequest, "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
googleError(c, http.StatusBadRequest, "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
// Get subscription (may be nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
|
||||
// For Gemini native API, do not send Claude-style ping frames.
|
||||
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
|
||||
|
||||
// 0) wait queue check
|
||||
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
|
||||
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
|
||||
|
||||
// 1) user concurrency slot
|
||||
streamStarted := false
|
||||
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
// 2) billing eligibility check (after wait)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
googleError(c, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
return
|
||||
}
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 5) forward (writes response to client)
|
||||
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// ForwardNative already wrote the response
|
||||
log.Printf("Gemini native forward failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6) record usage async
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func parseGeminiModelAction(rest string) (model string, action string, err error) {
|
||||
rest = strings.TrimSpace(rest)
|
||||
if rest == "" {
|
||||
return "", "", &pathParseError{"missing path"}
|
||||
}
|
||||
|
||||
// Standard: {model}:{action}
|
||||
if i := strings.Index(rest, ":"); i > 0 && i < len(rest)-1 {
|
||||
return rest[:i], rest[i+1:], nil
|
||||
}
|
||||
|
||||
// Fallback: {model}/{action}
|
||||
if i := strings.Index(rest, "/"); i > 0 && i < len(rest)-1 {
|
||||
return rest[:i], rest[i+1:], nil
|
||||
}
|
||||
|
||||
return "", "", &pathParseError{"invalid model action path"}
|
||||
}
|
||||
|
||||
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
|
||||
status, message := mapGeminiUpstreamError(statusCode)
|
||||
googleError(c, status, message)
|
||||
}
|
||||
|
||||
func mapGeminiUpstreamError(statusCode int) (int, string) {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
type pathParseError struct{ msg string }
|
||||
|
||||
func (e *pathParseError) Error() string { return e.msg }
|
||||
|
||||
func googleError(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": googleapi.HTTPStatusToGoogleStatus(status),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
|
||||
if res == nil {
|
||||
googleError(c, http.StatusBadGateway, "Empty upstream response")
|
||||
return
|
||||
}
|
||||
for k, vv := range res.Headers {
|
||||
// Avoid overriding content-length and hop-by-hop headers.
|
||||
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
c.Writer.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
contentType := res.Headers.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(res.StatusCode, contentType, res.Body)
|
||||
}
|
||||
|
||||
func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
||||
if res == nil {
|
||||
return true
|
||||
}
|
||||
if res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(strings.ToLower(res.Headers.Get("Www-Authenticate")), "insufficient_scope") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(strings.ToLower(string(res.Body)), "insufficient authentication scopes") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(strings.ToLower(string(res.Body)), "access_token_scope_insufficient") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -12,6 +12,7 @@ type AdminHandlers struct {
|
||||
Account *admin.AccountHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Setting *admin.SettingHandler
|
||||
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -46,7 +47,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
@@ -94,8 +95,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 0. Check if wait queue is full
|
||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
@@ -104,10 +105,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// Ensure wait count is decremented when function exits
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
|
||||
// 1. First acquire user concurrency slot
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, reqStream, &streamStarted)
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
@@ -118,7 +119,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2. Re-check billing eligibility after wait
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
|
||||
return
|
||||
@@ -127,49 +128,74 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Generate session hash (from header for OpenAI)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c)
|
||||
|
||||
// Select account supporting the requested model
|
||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel)
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
defer accountReleaseFunc()
|
||||
}
|
||||
|
||||
// Forward request
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
if err != nil {
|
||||
// Error response already handled in Forward, just log
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Async record usage
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: user,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
for {
|
||||
// Select account supporting the requested model
|
||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
}()
|
||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// Forward request
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// Error response already handled in Forward, just log
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Async record usage
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||
@@ -178,6 +204,28 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -37,15 +38,9 @@ type RedeemResponse struct {
|
||||
// Redeem handles redeeming a code
|
||||
// POST /api/v1/redeem
|
||||
func (h *RedeemHandler) Redeem(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -55,38 +50,36 @@ func (h *RedeemHandler) Redeem(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code)
|
||||
result, err := h.redeemService.Redeem(c.Request.Context(), subject.UserID, req.Code)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
response.Success(c, dto.RedeemCodeFromService(result))
|
||||
}
|
||||
|
||||
// GetHistory returns the user's redemption history
|
||||
// GET /api/v1/redeem/history
|
||||
func (h *RedeemHandler) GetHistory(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
// Default limit is 25
|
||||
limit := 25
|
||||
|
||||
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit)
|
||||
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), subject.UserID, limit)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, codes)
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -30,6 +31,17 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
settings.Version = h.version
|
||||
response.Success(c, settings)
|
||||
response.Success(c, dto.PublicSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
ApiBaseUrl: settings.ApiBaseUrl,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocUrl: settings.DocUrl,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -25,7 +26,7 @@ type SubscriptionSummaryItem struct {
|
||||
|
||||
// SubscriptionProgressInfo represents subscription with progress info
|
||||
type SubscriptionProgressInfo struct {
|
||||
Subscription *model.UserSubscription `json:"subscription"`
|
||||
Subscription *dto.UserSubscription `json:"subscription"`
|
||||
Progress *service.SubscriptionProgress `json:"progress"`
|
||||
}
|
||||
|
||||
@@ -44,68 +45,58 @@ func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *S
|
||||
// List handles listing current user's subscriptions
|
||||
// GET /api/v1/subscriptions
|
||||
func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
u, ok := user.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user in context")
|
||||
return
|
||||
}
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
|
||||
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, subscriptions)
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// GetActive handles getting current user's active subscriptions
|
||||
// GET /api/v1/subscriptions/active
|
||||
func (h *SubscriptionHandler) GetActive(c *gin.Context) {
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
u, ok := user.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user in context")
|
||||
return
|
||||
}
|
||||
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, subscriptions)
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// GetProgress handles getting subscription progress for current user
|
||||
// GET /api/v1/subscriptions/progress
|
||||
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
u, ok := user.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user in context")
|
||||
return
|
||||
}
|
||||
|
||||
// Get all active subscriptions with progress
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -120,7 +111,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
result = append(result, SubscriptionProgressInfo{
|
||||
Subscription: sub,
|
||||
Subscription: dto.UserSubscriptionFromService(sub),
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
@@ -131,20 +122,14 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
|
||||
// GetSummary handles getting a summary of current user's subscription status
|
||||
// GET /api/v1/subscriptions/summary
|
||||
func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
u, ok := user.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user in context")
|
||||
return
|
||||
}
|
||||
|
||||
// Get all active subscriptions
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
|
||||
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -30,15 +32,9 @@ func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.
|
||||
// List handles listing usage records with pagination
|
||||
// GET /api/v1/usage
|
||||
func (h *UsageHandler) List(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -58,7 +54,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != user.ID {
|
||||
if apiKey.UserID != subject.UserID {
|
||||
response.Forbidden(c, "Not authorized to access this API key's usage records")
|
||||
return
|
||||
}
|
||||
@@ -66,36 +62,82 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
apiKeyID = id
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
var records []model.UsageLog
|
||||
var result *pagination.PaginationResult
|
||||
var err error
|
||||
// Parse additional filters
|
||||
model := c.Query("model")
|
||||
|
||||
if apiKeyID > 0 {
|
||||
records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params)
|
||||
} else {
|
||||
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params)
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
return
|
||||
}
|
||||
stream = &val
|
||||
}
|
||||
|
||||
var billingType *int8
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
bt := int8(val)
|
||||
billingType = &bt
|
||||
}
|
||||
|
||||
// Parse date range
|
||||
var startTime, endTime *time.Time
|
||||
if startDateStr := c.Query("start_date"); startDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
startTime = &t
|
||||
}
|
||||
|
||||
if endDateStr := c.Query("end_date"); endDateStr != "" {
|
||||
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
// Set end time to end of day
|
||||
t = t.Add(24*time.Hour - time.Nanosecond)
|
||||
endTime = &t
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: subject.UserID, // Always filter by current user for security
|
||||
ApiKeyID: apiKeyID,
|
||||
Model: model,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
}
|
||||
|
||||
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, records, result.Total, page, pageSize)
|
||||
out := make([]dto.UsageLog, 0, len(records))
|
||||
for i := range records {
|
||||
out = append(out, *dto.UsageLogFromService(&records[i]))
|
||||
}
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a single usage record
|
||||
// GET /api/v1/usage/:id
|
||||
func (h *UsageHandler) GetByID(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -112,26 +154,20 @@ func (h *UsageHandler) GetByID(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if record.UserID != user.ID {
|
||||
if record.UserID != subject.UserID {
|
||||
response.Forbidden(c, "Not authorized to access this record")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, record)
|
||||
response.Success(c, dto.UsageLogFromService(record))
|
||||
}
|
||||
|
||||
// Stats handles getting usage statistics
|
||||
// GET /api/v1/usage/stats
|
||||
func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -149,7 +185,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
response.NotFound(c, "API key not found")
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != user.ID {
|
||||
if apiKey.UserID != subject.UserID {
|
||||
response.Forbidden(c, "Not authorized to access this API key's statistics")
|
||||
return
|
||||
}
|
||||
@@ -201,7 +237,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
if apiKeyID > 0 {
|
||||
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
} else {
|
||||
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime)
|
||||
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
|
||||
}
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -245,19 +281,13 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
|
||||
// DashboardStats handles getting user dashboard statistics
|
||||
// GET /api/v1/usage/dashboard/stats
|
||||
func (h *UsageHandler) DashboardStats(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
|
||||
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -269,22 +299,16 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
|
||||
// DashboardTrend handles getting user usage trend data
|
||||
// GET /api/v1/usage/dashboard/trend
|
||||
func (h *UsageHandler) DashboardTrend(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
startTime, endTime := parseUserTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
|
||||
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
|
||||
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), subject.UserID, startTime, endTime, granularity)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -301,21 +325,15 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
|
||||
// DashboardModels handles getting user model usage statistics
|
||||
// GET /api/v1/usage/dashboard/models
|
||||
func (h *UsageHandler) DashboardModels(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
startTime, endTime := parseUserTimeRange(c)
|
||||
|
||||
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
|
||||
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), subject.UserID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -336,15 +354,9 @@ type BatchApiKeysUsageRequest struct {
|
||||
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
|
||||
// POST /api/v1/usage/dashboard/api-keys-usage
|
||||
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -359,24 +371,16 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify ownership of all requested API keys
|
||||
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
// Limit the number of API key IDs to prevent SQL parameter overflow
|
||||
if len(req.ApiKeyIDs) > 100 {
|
||||
response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)")
|
||||
return
|
||||
}
|
||||
|
||||
userApiKeyIDs := make(map[int64]bool)
|
||||
for _, key := range userApiKeys {
|
||||
userApiKeyIDs[key.ID] = true
|
||||
}
|
||||
|
||||
// Filter to only include user's own API keys
|
||||
validApiKeyIDs := make([]int64, 0)
|
||||
for _, id := range req.ApiKeyIDs {
|
||||
if userApiKeyIDs[id] {
|
||||
validApiKeyIDs = append(validApiKeyIDs, id)
|
||||
}
|
||||
validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(validApiKeyIDs) == 0 {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -35,19 +36,13 @@ type UpdateProfileRequest struct {
|
||||
// GetProfile handles getting user profile
|
||||
// GET /api/v1/users/me
|
||||
func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
return
|
||||
}
|
||||
|
||||
userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
|
||||
userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -56,21 +51,15 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
userData.Notes = ""
|
||||
|
||||
response.Success(c, userData)
|
||||
response.Success(c, dto.UserFromService(userData))
|
||||
}
|
||||
|
||||
// ChangePassword handles changing user password
|
||||
// POST /api/v1/users/me/password
|
||||
func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,7 +73,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
CurrentPassword: req.OldPassword,
|
||||
NewPassword: req.NewPassword,
|
||||
}
|
||||
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq)
|
||||
err := h.userService.ChangePassword(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -96,15 +85,9 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
// UpdateProfile handles updating user profile
|
||||
// PUT /api/v1/users/me
|
||||
func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
userValue, exists := c.Get("user")
|
||||
if !exists {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := userValue.(*model.User)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.InternalError(c, "Invalid user context")
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -118,7 +101,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
Username: req.Username,
|
||||
Wechat: req.Wechat,
|
||||
}
|
||||
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq)
|
||||
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -127,5 +110,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
updatedUser.Notes = ""
|
||||
|
||||
response.Success(c, updatedUser)
|
||||
response.Success(c, dto.UserFromService(updatedUser))
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
|
||||
accountHandler *admin.AccountHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
settingHandler *admin.SettingHandler,
|
||||
@@ -29,6 +30,7 @@ func ProvideAdminHandlers(
|
||||
Account: accountHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Setting: settingHandler,
|
||||
@@ -95,6 +97,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAccountHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewSettingHandler,
|
||||
|
||||
@@ -2,8 +2,8 @@ package infrastructure
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
@@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
|
||||
// 自动迁移(始终执行,确保数据库结构与代码同步)
|
||||
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
|
||||
if err := model.AutoMigrate(db); err != nil {
|
||||
if err := repository.AutoMigrate(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,415 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// JSONB 用于存储JSONB数据
|
||||
type JSONB map[string]any
|
||||
|
||||
func (j JSONB) Value() (driver.Value, error) {
|
||||
if j == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
func (j *JSONB) Scan(value any) error {
|
||||
if value == nil {
|
||||
*j = nil
|
||||
return nil
|
||||
}
|
||||
bytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("type assertion to []byte failed")
|
||||
}
|
||||
return json.Unmarshal(bytes, j)
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"size:100;not null" json:"name"`
|
||||
Platform string `gorm:"size:50;not null" json:"platform"` // anthropic/openai/gemini
|
||||
Type string `gorm:"size:20;not null" json:"type"` // oauth/apikey
|
||||
Credentials JSONB `gorm:"type:jsonb;default:'{}'" json:"credentials"` // 凭证(加密存储)
|
||||
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
|
||||
ProxyID *int64 `gorm:"index" json:"proxy_id"`
|
||||
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
|
||||
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
||||
ErrorMessage string `gorm:"type:text" json:"error_message"`
|
||||
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// 调度控制
|
||||
Schedulable bool `gorm:"default:true;not null" json:"schedulable"`
|
||||
|
||||
// 限流状态 (429)
|
||||
RateLimitedAt *time.Time `gorm:"index" json:"rate_limited_at"`
|
||||
RateLimitResetAt *time.Time `gorm:"index" json:"rate_limit_reset_at"`
|
||||
|
||||
// 过载状态 (529)
|
||||
OverloadUntil *time.Time `gorm:"index" json:"overload_until"`
|
||||
|
||||
// 5小时时间窗口
|
||||
SessionWindowStart *time.Time `json:"session_window_start"`
|
||||
SessionWindowEnd *time.Time `json:"session_window_end"`
|
||||
SessionWindowStatus string `gorm:"size:20" json:"session_window_status"` // allowed/allowed_warning/rejected
|
||||
|
||||
// 关联
|
||||
Proxy *Proxy `gorm:"foreignKey:ProxyID" json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
|
||||
|
||||
// 虚拟字段 (不存储到数据库)
|
||||
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
|
||||
Groups []*Group `gorm:"-" json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
func (Account) TableName() string {
|
||||
return "accounts"
|
||||
}
|
||||
|
||||
// IsActive 检查是否激活
|
||||
func (a *Account) IsActive() bool {
|
||||
return a.Status == "active"
|
||||
}
|
||||
|
||||
// IsSchedulable 检查账号是否可调度
|
||||
func (a *Account) IsSchedulable() bool {
|
||||
if !a.IsActive() || !a.Schedulable {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
|
||||
return false
|
||||
}
|
||||
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsRateLimited 检查是否处于限流状态
|
||||
func (a *Account) IsRateLimited() bool {
|
||||
if a.RateLimitResetAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(*a.RateLimitResetAt)
|
||||
}
|
||||
|
||||
// IsOverloaded 检查是否处于过载状态
|
||||
func (a *Account) IsOverloaded() bool {
|
||||
if a.OverloadUntil == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(*a.OverloadUntil)
|
||||
}
|
||||
|
||||
// IsOAuth 检查是否为OAuth类型账号(包括oauth和setup-token)
|
||||
func (a *Account) IsOAuth() bool {
|
||||
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
|
||||
}
|
||||
|
||||
// CanGetUsage 检查账号是否可以获取usage信息(只有oauth类型可以,setup-token没有profile权限)
|
||||
func (a *Account) CanGetUsage() bool {
|
||||
return a.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// GetCredential 获取凭证字段
|
||||
func (a *Account) GetCredential(key string) string {
|
||||
if a.Credentials == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := a.Credentials[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetModelMapping 获取模型映射配置
|
||||
// 返回格式: map[请求模型名]实际模型名
|
||||
func (a *Account) GetModelMapping() map[string]string {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["model_mapping"]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
// 处理map[string]interface{}类型
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
result := make(map[string]string)
|
||||
for k, v := range m {
|
||||
if s, ok := v.(string); ok {
|
||||
result[k] = s
|
||||
}
|
||||
}
|
||||
if len(result) > 0 {
|
||||
return result
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsModelSupported 检查请求的模型是否被该账号支持
|
||||
// 如果没有设置模型映射,则支持所有模型
|
||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
mapping := a.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
return true // 没有映射配置,支持所有模型
|
||||
}
|
||||
_, exists := mapping[requestedModel]
|
||||
return exists
|
||||
}
|
||||
|
||||
// GetMappedModel 获取映射后的实际模型名
|
||||
// 如果没有映射,返回原始模型名
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mapping := a.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
return requestedModel
|
||||
}
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
return mappedModel
|
||||
}
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
// GetBaseURL 获取API基础URL(用于apikey类型账号)
|
||||
func (a *Account) GetBaseURL() string {
|
||||
if a.Type != AccountTypeApiKey {
|
||||
return ""
|
||||
}
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL == "" {
|
||||
return "https://api.anthropic.com" // 默认URL
|
||||
}
|
||||
return baseURL
|
||||
}
|
||||
|
||||
// GetExtraString 从Extra字段获取字符串值
|
||||
func (a *Account) GetExtraString(key string) string {
|
||||
if a.Extra == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsCustomErrorCodesEnabled 检查是否启用自定义错误码功能(仅适用于 apikey 类型)
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
if a.Type != AccountTypeApiKey || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetCustomErrorCodes 获取自定义错误码列表
|
||||
func (a *Account) GetCustomErrorCodes() []int {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["custom_error_codes"]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
// 处理 []interface{} 类型(JSON反序列化后的格式)
|
||||
if arr, ok := raw.([]any); ok {
|
||||
result := make([]int, 0, len(arr))
|
||||
for _, v := range arr {
|
||||
// JSON 数字默认解析为 float64
|
||||
if f, ok := v.(float64); ok {
|
||||
result = append(result, int(f))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ShouldHandleErrorCode 检查指定错误码是否应该被处理(停止调度/标记限流等)
|
||||
// 如果未启用自定义错误码或列表为空,返回 true(使用默认策略)
|
||||
// 如果启用且列表非空,只有在列表中的错误码才返回 true
|
||||
func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
|
||||
if !a.IsCustomErrorCodesEnabled() {
|
||||
return true // 未启用,使用默认策略
|
||||
}
|
||||
codes := a.GetCustomErrorCodes()
|
||||
if len(codes) == 0 {
|
||||
return true // 启用但列表为空,fallback到默认策略
|
||||
}
|
||||
// 检查是否在自定义列表中
|
||||
for _, code := range codes {
|
||||
if code == statusCode {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsInterceptWarmupEnabled 检查是否启用预热请求拦截
|
||||
// 启用后,标题生成、Warmup等预热请求将返回mock响应,不消耗上游token
|
||||
func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
if a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// =============== OpenAI 相关方法 ===============
|
||||
|
||||
// IsOpenAI 检查是否为 OpenAI 平台账号
|
||||
func (a *Account) IsOpenAI() bool {
|
||||
return a.Platform == PlatformOpenAI
|
||||
}
|
||||
|
||||
// IsAnthropic 检查是否为 Anthropic 平台账号
|
||||
func (a *Account) IsAnthropic() bool {
|
||||
return a.Platform == PlatformAnthropic
|
||||
}
|
||||
|
||||
// IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号
|
||||
func (a *Account) IsOpenAIOAuth() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// IsOpenAIApiKey 检查是否为 OpenAI API Key 类型账号(Response 账号)
|
||||
func (a *Account) IsOpenAIApiKey() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeApiKey
|
||||
}
|
||||
|
||||
// GetOpenAIBaseURL 获取 OpenAI API 基础 URL
|
||||
// 对于 API Key 类型账号,从 credentials 中获取 base_url
|
||||
// 对于 OAuth 类型账号,返回默认的 OpenAI API URL
|
||||
func (a *Account) GetOpenAIBaseURL() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
if a.Type == AccountTypeApiKey {
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL != "" {
|
||||
return baseURL
|
||||
}
|
||||
}
|
||||
return "https://api.openai.com" // OpenAI 默认 API URL
|
||||
}
|
||||
|
||||
// GetOpenAIAccessToken 获取 OpenAI 访问令牌
|
||||
func (a *Account) GetOpenAIAccessToken() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("access_token")
|
||||
}
|
||||
|
||||
// GetOpenAIRefreshToken 获取 OpenAI 刷新令牌
|
||||
func (a *Account) GetOpenAIRefreshToken() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("refresh_token")
|
||||
}
|
||||
|
||||
// GetOpenAIIDToken 获取 OpenAI ID Token(JWT,包含用户信息)
|
||||
func (a *Account) GetOpenAIIDToken() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("id_token")
|
||||
}
|
||||
|
||||
// GetOpenAIApiKey 获取 OpenAI API Key(用于 Response 账号)
|
||||
func (a *Account) GetOpenAIApiKey() string {
|
||||
if !a.IsOpenAIApiKey() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("api_key")
|
||||
}
|
||||
|
||||
// GetOpenAIUserAgent 获取 OpenAI 自定义 User-Agent
|
||||
// 返回空字符串表示透传原始 User-Agent
|
||||
func (a *Account) GetOpenAIUserAgent() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("user_agent")
|
||||
}
|
||||
|
||||
// GetChatGPTAccountID 获取 ChatGPT 账号 ID(从 ID Token 解析)
|
||||
func (a *Account) GetChatGPTAccountID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("chatgpt_account_id")
|
||||
}
|
||||
|
||||
// GetChatGPTUserID 获取 ChatGPT 用户 ID(从 ID Token 解析)
|
||||
func (a *Account) GetChatGPTUserID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("chatgpt_user_id")
|
||||
}
|
||||
|
||||
// GetOpenAIOrganizationID 获取 OpenAI 组织 ID
|
||||
func (a *Account) GetOpenAIOrganizationID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("organization_id")
|
||||
}
|
||||
|
||||
// GetOpenAITokenExpiresAt 获取 OpenAI Token 过期时间
|
||||
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return nil
|
||||
}
|
||||
expiresAtStr := a.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
return nil
|
||||
}
|
||||
// 尝试解析时间
|
||||
t, err := time.Parse(time.RFC3339, expiresAtStr)
|
||||
if err != nil {
|
||||
// 尝试解析为 Unix 时间戳
|
||||
if v, ok := a.Credentials["expires_at"].(float64); ok {
|
||||
t = time.Unix(int64(v), 0)
|
||||
return &t
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return &t
|
||||
}
|
||||
|
||||
// IsOpenAITokenExpired 检查 OpenAI Token 是否过期
|
||||
func (a *Account) IsOpenAITokenExpired() bool {
|
||||
expiresAt := a.GetOpenAITokenExpiresAt()
|
||||
if expiresAt == nil {
|
||||
return false // 没有过期时间信息,假设未过期
|
||||
}
|
||||
// 提前 60 秒认为过期,便于刷新
|
||||
return time.Now().Add(60 * time.Second).After(*expiresAt)
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type AccountGroup struct {
|
||||
AccountID int64 `gorm:"primaryKey" json:"account_id"`
|
||||
GroupID int64 `gorm:"primaryKey" json:"group_id"`
|
||||
Priority int `gorm:"default:50;not null" json:"priority"` // 分组内优先级
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
|
||||
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
|
||||
}
|
||||
|
||||
func (AccountGroup) TableName() string {
|
||||
return "account_groups"
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ApiKey struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
UserID int64 `gorm:"index;not null" json:"user_id"`
|
||||
Key string `gorm:"uniqueIndex;size:128;not null" json:"key"` // sk-xxx
|
||||
Name string `gorm:"size:100;not null" json:"name"`
|
||||
GroupID *int64 `gorm:"index" json:"group_id"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
|
||||
}
|
||||
|
||||
func (ApiKey) TableName() string {
|
||||
return "api_keys"
|
||||
}
|
||||
|
||||
// IsActive 检查是否激活
|
||||
func (k *ApiKey) IsActive() bool {
|
||||
return k.Status == "active"
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 订阅类型常量
|
||||
const (
|
||||
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
|
||||
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
||||
Description string `gorm:"type:text" json:"description"`
|
||||
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
||||
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
|
||||
// 订阅功能字段
|
||||
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
|
||||
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"monthly_limit_usd"`
|
||||
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// 关联
|
||||
AccountGroups []AccountGroup `gorm:"foreignKey:GroupID" json:"account_groups,omitempty"`
|
||||
|
||||
// 虚拟字段 (不存储到数据库)
|
||||
AccountCount int64 `gorm:"-" json:"account_count,omitempty"`
|
||||
}
|
||||
|
||||
func (Group) TableName() string {
|
||||
return "groups"
|
||||
}
|
||||
|
||||
// IsActive 检查是否激活
|
||||
func (g *Group) IsActive() bool {
|
||||
return g.Status == "active"
|
||||
}
|
||||
|
||||
// IsSubscriptionType 检查是否为订阅类型分组
|
||||
func (g *Group) IsSubscriptionType() bool {
|
||||
return g.SubscriptionType == SubscriptionTypeSubscription
|
||||
}
|
||||
|
||||
// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额)
|
||||
func (g *Group) IsFreeSubscription() bool {
|
||||
return g.IsSubscriptionType() && g.RateMultiplier == 0
|
||||
}
|
||||
|
||||
// HasDailyLimit 检查是否有日限额
|
||||
func (g *Group) HasDailyLimit() bool {
|
||||
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
|
||||
}
|
||||
|
||||
// HasWeeklyLimit 检查是否有周限额
|
||||
func (g *Group) HasWeeklyLimit() bool {
|
||||
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
|
||||
}
|
||||
|
||||
// HasMonthlyLimit 检查是否有月限额
|
||||
func (g *Group) HasMonthlyLimit() bool {
|
||||
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AutoMigrate 自动迁移所有模型
|
||||
func AutoMigrate(db *gorm.DB) error {
|
||||
return db.AutoMigrate(
|
||||
&User{},
|
||||
&ApiKey{},
|
||||
&Group{},
|
||||
&Account{},
|
||||
&AccountGroup{},
|
||||
&Proxy{},
|
||||
&RedeemCode{},
|
||||
&UsageLog{},
|
||||
&Setting{},
|
||||
&UserSubscription{},
|
||||
)
|
||||
}
|
||||
|
||||
// 状态常量
|
||||
const (
|
||||
StatusActive = "active"
|
||||
StatusDisabled = "disabled"
|
||||
StatusError = "error"
|
||||
StatusUnused = "unused"
|
||||
StatusUsed = "used"
|
||||
StatusExpired = "expired"
|
||||
)
|
||||
|
||||
// 角色常量
|
||||
const (
|
||||
RoleAdmin = "admin"
|
||||
RoleUser = "user"
|
||||
)
|
||||
|
||||
// 平台常量
|
||||
const (
|
||||
PlatformAnthropic = "anthropic"
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
)
|
||||
|
||||
// 账号类型常量
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeApiKey = "apikey" // API Key类型账号
|
||||
)
|
||||
|
||||
// 卡密类型常量
|
||||
const (
|
||||
RedeemTypeBalance = "balance"
|
||||
RedeemTypeConcurrency = "concurrency"
|
||||
RedeemTypeSubscription = "subscription"
|
||||
)
|
||||
|
||||
// 管理员调整类型常量
|
||||
const (
|
||||
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
|
||||
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
|
||||
)
|
||||
@@ -1,45 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Proxy struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"size:100;not null" json:"name"`
|
||||
Protocol string `gorm:"size:20;not null" json:"protocol"` // http/https/socks5
|
||||
Host string `gorm:"size:255;not null" json:"host"`
|
||||
Port int `gorm:"not null" json:"port"`
|
||||
Username string `gorm:"size:100" json:"username"`
|
||||
Password string `gorm:"size:100" json:"-"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
}
|
||||
|
||||
func (Proxy) TableName() string {
|
||||
return "proxies"
|
||||
}
|
||||
|
||||
// IsActive 检查是否激活
|
||||
func (p *Proxy) IsActive() bool {
|
||||
return p.Status == "active"
|
||||
}
|
||||
|
||||
// URL 返回代理URL
|
||||
func (p *Proxy) URL() string {
|
||||
if p.Username != "" && p.Password != "" {
|
||||
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
|
||||
}
|
||||
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
|
||||
}
|
||||
|
||||
// ProxyWithAccountCount extends Proxy with account count information
|
||||
type ProxyWithAccountCount struct {
|
||||
Proxy
|
||||
AccountCount int64 `json:"account_count"`
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RedeemCode struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
|
||||
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
||||
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
||||
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
|
||||
UsedBy *int64 `gorm:"index" json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
Notes string `gorm:"type:text" json:"notes"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
|
||||
// 订阅类型专用字段
|
||||
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
||||
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UsedBy" json:"user,omitempty"`
|
||||
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
|
||||
}
|
||||
|
||||
func (RedeemCode) TableName() string {
|
||||
return "redeem_codes"
|
||||
}
|
||||
|
||||
// IsUsed 检查是否已使用
|
||||
func (r *RedeemCode) IsUsed() bool {
|
||||
return r.Status == "used"
|
||||
}
|
||||
|
||||
// CanUse 检查是否可以使用
|
||||
func (r *RedeemCode) CanUse() bool {
|
||||
return r.Status == "unused"
|
||||
}
|
||||
|
||||
// GenerateRedeemCode 生成唯一的兑换码
|
||||
func GenerateRedeemCode() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Setting 系统设置模型(Key-Value存储)
|
||||
type Setting struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Key string `gorm:"uniqueIndex;size:100;not null" json:"key"`
|
||||
Value string `gorm:"type:text;not null" json:"value"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (Setting) TableName() string {
|
||||
return "settings"
|
||||
}
|
||||
|
||||
// 设置Key常量
|
||||
const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
|
||||
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
|
||||
|
||||
// OEM设置
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocUrl = "doc_url" // 文档链接
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
|
||||
// 管理员 API Key
|
||||
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
)
|
||||
|
||||
// 管理员 API Key 前缀(与用户 sk- 前缀区分)
|
||||
const AdminApiKeyPrefix = "admin-"
|
||||
|
||||
// SystemSettings 系统设置结构体(用于API响应)
|
||||
type SystemSettings struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
// 邮件服务设置
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password,omitempty"` // 不返回明文密码
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` // 不返回明文密钥
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
}
|
||||
|
||||
// PublicSettings 公开设置(无需登录即可获取)
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// 消费类型常量
|
||||
const (
|
||||
BillingTypeBalance int8 = 0 // 钱包余额
|
||||
BillingTypeSubscription int8 = 1 // 订阅套餐
|
||||
)
|
||||
|
||||
type UsageLog struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
UserID int64 `gorm:"index;not null" json:"user_id"`
|
||||
ApiKeyID int64 `gorm:"index;not null" json:"api_key_id"`
|
||||
AccountID int64 `gorm:"index;not null" json:"account_id"`
|
||||
RequestID string `gorm:"size:64" json:"request_id"`
|
||||
Model string `gorm:"size:100;index;not null" json:"model"`
|
||||
|
||||
// 订阅关联(可选)
|
||||
GroupID *int64 `gorm:"index" json:"group_id"`
|
||||
SubscriptionID *int64 `gorm:"index" json:"subscription_id"`
|
||||
|
||||
// Token使用量(4类)
|
||||
InputTokens int `gorm:"default:0;not null" json:"input_tokens"`
|
||||
OutputTokens int `gorm:"default:0;not null" json:"output_tokens"`
|
||||
CacheCreationTokens int `gorm:"default:0;not null" json:"cache_creation_tokens"`
|
||||
CacheReadTokens int `gorm:"default:0;not null" json:"cache_read_tokens"`
|
||||
|
||||
// 详细的缓存创建分类
|
||||
CacheCreation5mTokens int `gorm:"default:0;not null" json:"cache_creation_5m_tokens"`
|
||||
CacheCreation1hTokens int `gorm:"default:0;not null" json:"cache_creation_1h_tokens"`
|
||||
|
||||
// 费用(USD)
|
||||
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"`
|
||||
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
|
||||
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
|
||||
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
||||
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
|
||||
|
||||
// 元数据
|
||||
BillingType int8 `gorm:"type:smallint;default:0;not null" json:"billing_type"` // 0=余额 1=订阅
|
||||
Stream bool `gorm:"default:false;not null" json:"stream"`
|
||||
DurationMs *int `json:"duration_ms"`
|
||||
FirstTokenMs *int `json:"first_token_ms"` // 首字时间(流式请求)
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
ApiKey *ApiKey `gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"`
|
||||
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
|
||||
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
|
||||
Subscription *UserSubscription `gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"`
|
||||
}
|
||||
|
||||
func (UsageLog) TableName() string {
|
||||
return "usage_logs"
|
||||
}
|
||||
|
||||
// TotalTokens 总token数
|
||||
func (u *UsageLog) TotalTokens() int {
|
||||
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
||||
Username string `gorm:"size:100;default:''" json:"username"`
|
||||
Wechat string `gorm:"size:100;default:''" json:"wechat"`
|
||||
Notes string `gorm:"type:text;default:''" json:"notes"`
|
||||
PasswordHash string `gorm:"size:255;not null" json:"-"`
|
||||
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
|
||||
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
|
||||
Concurrency int `gorm:"default:5;not null" json:"concurrency"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||
AllowedGroups pq.Int64Array `gorm:"type:bigint[]" json:"allowed_groups"`
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// 关联
|
||||
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
|
||||
Subscriptions []UserSubscription `gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
// IsAdmin 检查是否管理员
|
||||
func (u *User) IsAdmin() bool {
|
||||
return u.Role == "admin"
|
||||
}
|
||||
|
||||
// IsActive 检查是否激活
|
||||
func (u *User) IsActive() bool {
|
||||
return u.Status == "active"
|
||||
}
|
||||
|
||||
// CanBindGroup 检查是否可以绑定指定分组
|
||||
// 对于标准类型分组:
|
||||
// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组
|
||||
// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组
|
||||
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
|
||||
// 如果设置了 allowed_groups 且不为空,只能绑定指定的分组
|
||||
if len(u.AllowedGroups) > 0 {
|
||||
for _, id := range u.AllowedGroups {
|
||||
if id == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
// 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组
|
||||
return !isExclusive
|
||||
}
|
||||
|
||||
// SetPassword 设置密码(哈希存储)
|
||||
func (u *User) SetPassword(password string) error {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.PasswordHash = string(hash)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPassword 验证密码
|
||||
func (u *User) CheckPassword(password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
@@ -1,157 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// 订阅状态常量
|
||||
const (
|
||||
SubscriptionStatusActive = "active"
|
||||
SubscriptionStatusExpired = "expired"
|
||||
SubscriptionStatusSuspended = "suspended"
|
||||
)
|
||||
|
||||
// UserSubscription 用户订阅模型
|
||||
type UserSubscription struct {
|
||||
ID int64 `gorm:"primaryKey" json:"id"`
|
||||
UserID int64 `gorm:"index;not null" json:"user_id"`
|
||||
GroupID int64 `gorm:"index;not null" json:"group_id"`
|
||||
|
||||
// 订阅有效期
|
||||
StartsAt time.Time `gorm:"not null" json:"starts_at"`
|
||||
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
|
||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/expired/suspended
|
||||
|
||||
// 滑动窗口起始时间(nil = 未激活)
|
||||
DailyWindowStart *time.Time `json:"daily_window_start"`
|
||||
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
|
||||
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
|
||||
|
||||
// 当前窗口已用额度(USD,基于 total_cost 计算)
|
||||
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"daily_usage_usd"`
|
||||
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"weekly_usage_usd"`
|
||||
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"monthly_usage_usd"`
|
||||
|
||||
// 管理员分配信息
|
||||
AssignedBy *int64 `gorm:"index" json:"assigned_by"`
|
||||
AssignedAt time.Time `gorm:"not null" json:"assigned_at"`
|
||||
Notes string `gorm:"type:text" json:"notes"`
|
||||
|
||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
|
||||
AssignedByUser *User `gorm:"foreignKey:AssignedBy" json:"assigned_by_user,omitempty"`
|
||||
}
|
||||
|
||||
func (UserSubscription) TableName() string {
|
||||
return "user_subscriptions"
|
||||
}
|
||||
|
||||
// IsActive 检查订阅是否有效(状态为active且未过期)
|
||||
func (s *UserSubscription) IsActive() bool {
|
||||
return s.Status == SubscriptionStatusActive && time.Now().Before(s.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsExpired 检查订阅是否已过期
|
||||
func (s *UserSubscription) IsExpired() bool {
|
||||
return time.Now().After(s.ExpiresAt)
|
||||
}
|
||||
|
||||
// DaysRemaining 返回订阅剩余天数
|
||||
func (s *UserSubscription) DaysRemaining() int {
|
||||
if s.IsExpired() {
|
||||
return 0
|
||||
}
|
||||
return int(time.Until(s.ExpiresAt).Hours() / 24)
|
||||
}
|
||||
|
||||
// IsWindowActivated 检查窗口是否已激活
|
||||
func (s *UserSubscription) IsWindowActivated() bool {
|
||||
return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil
|
||||
}
|
||||
|
||||
// NeedsDailyReset 检查日窗口是否需要重置
|
||||
func (s *UserSubscription) NeedsDailyReset() bool {
|
||||
if s.DailyWindowStart == nil {
|
||||
return false
|
||||
}
|
||||
return time.Since(*s.DailyWindowStart) >= 24*time.Hour
|
||||
}
|
||||
|
||||
// NeedsWeeklyReset 检查周窗口是否需要重置
|
||||
func (s *UserSubscription) NeedsWeeklyReset() bool {
|
||||
if s.WeeklyWindowStart == nil {
|
||||
return false
|
||||
}
|
||||
return time.Since(*s.WeeklyWindowStart) >= 7*24*time.Hour
|
||||
}
|
||||
|
||||
// NeedsMonthlyReset 检查月窗口是否需要重置
|
||||
func (s *UserSubscription) NeedsMonthlyReset() bool {
|
||||
if s.MonthlyWindowStart == nil {
|
||||
return false
|
||||
}
|
||||
return time.Since(*s.MonthlyWindowStart) >= 30*24*time.Hour
|
||||
}
|
||||
|
||||
// DailyResetTime 返回日窗口重置时间
|
||||
func (s *UserSubscription) DailyResetTime() *time.Time {
|
||||
if s.DailyWindowStart == nil {
|
||||
return nil
|
||||
}
|
||||
t := s.DailyWindowStart.Add(24 * time.Hour)
|
||||
return &t
|
||||
}
|
||||
|
||||
// WeeklyResetTime 返回周窗口重置时间
|
||||
func (s *UserSubscription) WeeklyResetTime() *time.Time {
|
||||
if s.WeeklyWindowStart == nil {
|
||||
return nil
|
||||
}
|
||||
t := s.WeeklyWindowStart.Add(7 * 24 * time.Hour)
|
||||
return &t
|
||||
}
|
||||
|
||||
// MonthlyResetTime 返回月窗口重置时间
|
||||
func (s *UserSubscription) MonthlyResetTime() *time.Time {
|
||||
if s.MonthlyWindowStart == nil {
|
||||
return nil
|
||||
}
|
||||
t := s.MonthlyWindowStart.Add(30 * 24 * time.Hour)
|
||||
return &t
|
||||
}
|
||||
|
||||
// CheckDailyLimit 检查是否超出日限额
|
||||
func (s *UserSubscription) CheckDailyLimit(group *Group, additionalCost float64) bool {
|
||||
if !group.HasDailyLimit() {
|
||||
return true // 无限制
|
||||
}
|
||||
return s.DailyUsageUSD+additionalCost <= *group.DailyLimitUSD
|
||||
}
|
||||
|
||||
// CheckWeeklyLimit 检查是否超出周限额
|
||||
func (s *UserSubscription) CheckWeeklyLimit(group *Group, additionalCost float64) bool {
|
||||
if !group.HasWeeklyLimit() {
|
||||
return true // 无限制
|
||||
}
|
||||
return s.WeeklyUsageUSD+additionalCost <= *group.WeeklyLimitUSD
|
||||
}
|
||||
|
||||
// CheckMonthlyLimit 检查是否超出月限额
|
||||
func (s *UserSubscription) CheckMonthlyLimit(group *Group, additionalCost float64) bool {
|
||||
if !group.HasMonthlyLimit() {
|
||||
return true // 无限制
|
||||
}
|
||||
return s.MonthlyUsageUSD+additionalCost <= *group.MonthlyLimitUSD
|
||||
}
|
||||
|
||||
// CheckAllLimits 检查所有限额
|
||||
func (s *UserSubscription) CheckAllLimits(group *Group, additionalCost float64) (daily, weekly, monthly bool) {
|
||||
daily = s.CheckDailyLimit(group, additionalCost)
|
||||
weekly = s.CheckWeeklyLimit(group, additionalCost)
|
||||
monthly = s.CheckMonthlyLimit(group, additionalCost)
|
||||
return
|
||||
}
|
||||
42
backend/internal/pkg/gemini/models.go
Normal file
42
backend/internal/pkg/gemini/models.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package gemini
|
||||
|
||||
// This package provides minimal fallback model metadata for Gemini native endpoints.
|
||||
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
|
||||
|
||||
type Model struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
|
||||
}
|
||||
|
||||
type ModelsListResponse struct {
|
||||
Models []Model `json:"models"`
|
||||
}
|
||||
|
||||
func DefaultModels() []Model {
|
||||
methods := []string{"generateContent", "streamGenerateContent"}
|
||||
return []Model{
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.0-flash-lite", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
|
||||
}
|
||||
}
|
||||
|
||||
func FallbackModelsList() ModelsListResponse {
|
||||
return ModelsListResponse{Models: DefaultModels()}
|
||||
}
|
||||
|
||||
func FallbackModel(model string) Model {
|
||||
methods := []string{"generateContent", "streamGenerateContent"}
|
||||
if model == "" {
|
||||
return Model{Name: "models/unknown", SupportedGenerationMethods: methods}
|
||||
}
|
||||
if len(model) >= 7 && model[:7] == "models/" {
|
||||
return Model{Name: model, SupportedGenerationMethods: methods}
|
||||
}
|
||||
return Model{Name: "models/" + model, SupportedGenerationMethods: methods}
|
||||
}
|
||||
38
backend/internal/pkg/geminicli/codeassist_types.go
Normal file
38
backend/internal/pkg/geminicli/codeassist_types.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package geminicli
|
||||
|
||||
// LoadCodeAssistRequest matches done-hub's internal Code Assist call.
|
||||
type LoadCodeAssistRequest struct {
|
||||
Metadata LoadCodeAssistMetadata `json:"metadata"`
|
||||
}
|
||||
|
||||
type LoadCodeAssistMetadata struct {
|
||||
IDEType string `json:"ideType"`
|
||||
Platform string `json:"platform"`
|
||||
PluginType string `json:"pluginType"`
|
||||
}
|
||||
|
||||
type LoadCodeAssistResponse struct {
|
||||
CurrentTier string `json:"currentTier,omitempty"`
|
||||
CloudAICompanionProject string `json:"cloudaicompanionProject,omitempty"`
|
||||
AllowedTiers []AllowedTier `json:"allowedTiers,omitempty"`
|
||||
}
|
||||
|
||||
type AllowedTier struct {
|
||||
ID string `json:"id"`
|
||||
IsDefault bool `json:"isDefault,omitempty"`
|
||||
}
|
||||
|
||||
type OnboardUserRequest struct {
|
||||
TierID string `json:"tierId"`
|
||||
Metadata LoadCodeAssistMetadata `json:"metadata"`
|
||||
}
|
||||
|
||||
type OnboardUserResponse struct {
|
||||
Done bool `json:"done"`
|
||||
Response *OnboardUserResultData `json:"response,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
type OnboardUserResultData struct {
|
||||
CloudAICompanionProject any `json:"cloudaicompanionProject,omitempty"`
|
||||
}
|
||||
42
backend/internal/pkg/geminicli/constants.go
Normal file
42
backend/internal/pkg/geminicli/constants.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package geminicli
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
AIStudioBaseURL = "https://generativelanguage.googleapis.com"
|
||||
GeminiCliBaseURL = "https://cloudcode-pa.googleapis.com"
|
||||
|
||||
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
TokenURL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
// AIStudioOAuthRedirectURI is the default redirect URI used for AI Studio OAuth.
|
||||
// This matches the "copy/paste callback URL" flow used by OpenAI OAuth in this project.
|
||||
// Note: You still need to register this redirect URI in your Google OAuth client
|
||||
// unless you use an OAuth client type that permits localhost redirect URIs.
|
||||
AIStudioOAuthRedirectURI = "http://localhost:1455/auth/callback"
|
||||
|
||||
// DefaultScopes for Code Assist (includes cloud-platform for API access plus userinfo scopes)
|
||||
// Required by Google's Code Assist API.
|
||||
DefaultCodeAssistScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
|
||||
|
||||
// DefaultScopes for AI Studio (uses generativelanguage API with OAuth)
|
||||
// Reference: https://ai.google.dev/gemini-api/docs/oauth
|
||||
// For regular Google accounts, supports API calls to generativelanguage.googleapis.com
|
||||
// Note: Google Auth platform currently documents the OAuth scope as
|
||||
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
|
||||
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
|
||||
|
||||
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
|
||||
GeminiCLIRedirectURI = "https://codeassist.google.com/authcode"
|
||||
|
||||
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
|
||||
// They enable the "login without creating your own OAuth client" experience, but Google may
|
||||
// restrict which scopes are allowed for this client.
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
GeminiCLIOAuthClientSecret = "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"
|
||||
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
// GeminiCLIUserAgent mimics Gemini CLI to maximize compatibility with internal endpoints.
|
||||
GeminiCLIUserAgent = "GeminiCLI/0.1.5 (Windows; AMD64)"
|
||||
)
|
||||
21
backend/internal/pkg/geminicli/models.go
Normal file
21
backend/internal/pkg/geminicli/models.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package geminicli
|
||||
|
||||
// Model represents a selectable Gemini model for UI/testing purposes.
|
||||
// Keep JSON fields consistent with existing frontend expectations.
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gemini-3-pro", Type: "model", DisplayName: "Gemini 3 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||
}
|
||||
|
||||
// DefaultTestModel is the default model to preselect in test flows.
|
||||
const DefaultTestModel = "gemini-2.5-pro"
|
||||
243
backend/internal/pkg/geminicli/oauth.go
Normal file
243
backend/internal/pkg/geminicli/oauth.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OAuthConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
Scopes string
|
||||
}
|
||||
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type"` // "code_assist" 或 "ai_studio"
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
func (s *SessionStore) Stop() {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
default:
|
||||
close(s.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier returns an RFC 7636 compatible code verifier (43+ chars).
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
func base64URLEncode(data []byte) string {
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
|
||||
}
|
||||
|
||||
// EffectiveOAuthConfig returns the effective OAuth configuration.
|
||||
// oauthType: "code_assist" or "ai_studio" (defaults to "code_assist" if empty).
|
||||
//
|
||||
// If ClientID/ClientSecret is not provided, this falls back to the built-in Gemini CLI OAuth client.
|
||||
//
|
||||
// Note: The built-in Gemini CLI OAuth client is restricted and may reject some scopes (e.g.
|
||||
// https://www.googleapis.com/auth/generative-language), which will surface as
|
||||
// "restricted_client" / "Unregistered scope(s)" errors during browser authorization.
|
||||
func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error) {
|
||||
effective := OAuthConfig{
|
||||
ClientID: strings.TrimSpace(cfg.ClientID),
|
||||
ClientSecret: strings.TrimSpace(cfg.ClientSecret),
|
||||
Scopes: strings.TrimSpace(cfg.Scopes),
|
||||
}
|
||||
|
||||
// Normalize scopes: allow comma-separated input but send space-delimited scopes to Google.
|
||||
if effective.Scopes != "" {
|
||||
effective.Scopes = strings.Join(strings.Fields(strings.ReplaceAll(effective.Scopes, ",", " ")), " ")
|
||||
}
|
||||
|
||||
// Fall back to built-in Gemini CLI OAuth client when not configured.
|
||||
if effective.ClientID == "" && effective.ClientSecret == "" {
|
||||
effective.ClientID = GeminiCLIOAuthClientID
|
||||
effective.ClientSecret = GeminiCLIOAuthClientSecret
|
||||
} else if effective.ClientID == "" || effective.ClientSecret == "" {
|
||||
return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
|
||||
}
|
||||
|
||||
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID &&
|
||||
effective.ClientSecret == GeminiCLIOAuthClientSecret
|
||||
|
||||
if effective.Scopes == "" {
|
||||
// Use different default scopes based on OAuth type
|
||||
if oauthType == "ai_studio" {
|
||||
// Built-in client can't request some AI Studio scopes (notably generative-language).
|
||||
if isBuiltinClient {
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
} else {
|
||||
effective.Scopes = DefaultAIStudioScopes
|
||||
}
|
||||
} else {
|
||||
// Default to Code Assist scopes
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
}
|
||||
} else if oauthType == "ai_studio" && isBuiltinClient {
|
||||
// If user overrides scopes while still using the built-in client, strip restricted scopes.
|
||||
parts := strings.Fields(effective.Scopes)
|
||||
filtered := make([]string, 0, len(parts))
|
||||
for _, s := range parts {
|
||||
if strings.Contains(s, "generative-language") {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, s)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
effective.Scopes = DefaultCodeAssistScopes
|
||||
} else {
|
||||
effective.Scopes = strings.Join(filtered, " ")
|
||||
}
|
||||
}
|
||||
|
||||
// Backward compatibility: normalize older AI Studio scope to the currently documented one.
|
||||
if oauthType == "ai_studio" && effective.Scopes != "" {
|
||||
parts := strings.Fields(effective.Scopes)
|
||||
for i := range parts {
|
||||
if parts[i] == "https://www.googleapis.com/auth/generative-language" {
|
||||
parts[i] = "https://www.googleapis.com/auth/generative-language.retriever"
|
||||
}
|
||||
}
|
||||
effective.Scopes = strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
return effective, nil
|
||||
}
|
||||
|
||||
func BuildAuthorizationURL(cfg OAuthConfig, state, codeChallenge, redirectURI, projectID, oauthType string) (string, error) {
|
||||
effectiveCfg, err := EffectiveOAuthConfig(cfg, oauthType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
redirectURI = strings.TrimSpace(redirectURI)
|
||||
if redirectURI == "" {
|
||||
return "", fmt.Errorf("redirect_uri is required")
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", effectiveCfg.ClientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scope", effectiveCfg.Scopes)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
params.Set("access_type", "offline")
|
||||
params.Set("prompt", "consent")
|
||||
params.Set("include_granted_scopes", "true")
|
||||
if strings.TrimSpace(projectID) != "" {
|
||||
params.Set("project_id", strings.TrimSpace(projectID))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()), nil
|
||||
}
|
||||
46
backend/internal/pkg/geminicli/sanitize.go
Normal file
46
backend/internal/pkg/geminicli/sanitize.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package geminicli
|
||||
|
||||
import "strings"
|
||||
|
||||
const maxLogBodyLen = 2048
|
||||
|
||||
func SanitizeBodyForLogs(body string) string {
|
||||
body = truncateBase64InMessage(body)
|
||||
if len(body) > maxLogBodyLen {
|
||||
body = body[:maxLogBodyLen] + "...[truncated]"
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func truncateBase64InMessage(message string) string {
|
||||
const maxBase64Length = 50
|
||||
|
||||
result := message
|
||||
offset := 0
|
||||
for {
|
||||
idx := strings.Index(result[offset:], ";base64,")
|
||||
if idx == -1 {
|
||||
break
|
||||
}
|
||||
actualIdx := offset + idx
|
||||
start := actualIdx + len(";base64,")
|
||||
|
||||
end := start
|
||||
for end < len(result) && isBase64Char(result[end]) {
|
||||
end++
|
||||
}
|
||||
|
||||
if end-start > maxBase64Length {
|
||||
result = result[:start+maxBase64Length] + "...[truncated]" + result[end:]
|
||||
offset = start + maxBase64Length + len("...[truncated]")
|
||||
continue
|
||||
}
|
||||
offset = end
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func isBase64Char(c byte) bool {
|
||||
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '+' || c == '/' || c == '='
|
||||
}
|
||||
9
backend/internal/pkg/geminicli/token_types.go
Normal file
9
backend/internal/pkg/geminicli/token_types.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package geminicli
|
||||
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
24
backend/internal/pkg/googleapi/status.go
Normal file
24
backend/internal/pkg/googleapi/status.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package googleapi
|
||||
|
||||
import "net/http"
|
||||
|
||||
// HTTPStatusToGoogleStatus maps HTTP status codes to Google-style error status strings.
|
||||
func HTTPStatusToGoogleStatus(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return "INVALID_ARGUMENT"
|
||||
case http.StatusUnauthorized:
|
||||
return "UNAUTHENTICATED"
|
||||
case http.StatusForbidden:
|
||||
return "PERMISSION_DENIED"
|
||||
case http.StatusNotFound:
|
||||
return "NOT_FOUND"
|
||||
case http.StatusTooManyRequests:
|
||||
return "RESOURCE_EXHAUSTED"
|
||||
default:
|
||||
if status >= 500 {
|
||||
return "INTERNAL"
|
||||
}
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
@@ -127,10 +127,15 @@ type UserDashboardStats struct {
|
||||
|
||||
// UsageLogFilters represents filters for usage log queries
|
||||
type UsageLogFilters struct {
|
||||
UserID int64
|
||||
ApiKeyID int64
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
UserID int64
|
||||
ApiKeyID int64
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Model string
|
||||
Stream *bool
|
||||
BillingType *int8
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
}
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
|
||||
@@ -5,10 +5,10 @@ import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
@@ -21,69 +21,66 @@ func NewAccountRepository(db *gorm.DB) service.AccountRepository {
|
||||
return &accountRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *accountRepository) Create(ctx context.Context, account *model.Account) error {
|
||||
return r.db.WithContext(ctx).Create(account).Error
|
||||
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
|
||||
m := accountModelFromService(account)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
if err == nil {
|
||||
applyAccountModelToService(account, m)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
|
||||
var account model.Account
|
||||
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
|
||||
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||
var m accountModel
|
||||
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&m, id).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||||
}
|
||||
// 填充 GroupIDs 和 Groups 虚拟字段
|
||||
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
|
||||
account.Groups = make([]*model.Group, 0, len(account.AccountGroups))
|
||||
for _, ag := range account.AccountGroups {
|
||||
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
|
||||
if ag.Group != nil {
|
||||
account.Groups = append(account.Groups, ag.Group)
|
||||
}
|
||||
}
|
||||
return &account, nil
|
||||
return accountModelToService(&m), nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
|
||||
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
|
||||
if crsAccountID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var account model.Account
|
||||
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error
|
||||
var m accountModel
|
||||
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&m).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &account, nil
|
||||
return accountModelToService(&m), nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) Update(ctx context.Context, account *model.Account) error {
|
||||
return r.db.WithContext(ctx).Save(account).Error
|
||||
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
|
||||
m := accountModelFromService(account)
|
||||
err := r.db.WithContext(ctx).Save(m).Error
|
||||
if err == nil {
|
||||
applyAccountModelToService(account, m)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
// 先删除账号与分组的绑定关系
|
||||
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
|
||||
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&accountGroupModel{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 再删除账号
|
||||
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
|
||||
return r.db.WithContext(ctx).Delete(&accountModel{}, id).Error
|
||||
}
|
||||
|
||||
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
|
||||
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
var accounts []model.Account
|
||||
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
var accounts []accountModel
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.Account{})
|
||||
db := r.db.WithContext(ctx).Model(&accountModel{})
|
||||
|
||||
// Apply filters
|
||||
if platform != "" {
|
||||
db = db.Where("platform = ?", platform)
|
||||
}
|
||||
@@ -106,67 +103,105 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 填充每个 Account 的虚拟字段(GroupIDs 和 Groups)
|
||||
outAccounts := make([]service.Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
|
||||
accounts[i].Groups = make([]*model.Group, 0, len(accounts[i].AccountGroups))
|
||||
for _, ag := range accounts[i].AccountGroups {
|
||||
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
|
||||
if ag.Group != nil {
|
||||
accounts[i].Groups = append(accounts[i].Groups, ag.Group)
|
||||
}
|
||||
}
|
||||
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return accounts, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
return outAccounts, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
var accounts []accountModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
|
||||
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, model.StatusActive).
|
||||
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, service.StatusActive).
|
||||
Preload("Proxy").
|
||||
Order("account_groups.priority ASC, accounts.priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outAccounts := make([]service.Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
|
||||
}
|
||||
return outAccounts, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) {
|
||||
var accounts []accountModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", model.StatusActive).
|
||||
Where("status = ?", service.StatusActive).
|
||||
Preload("Proxy").
|
||||
Order("priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outAccounts := make([]service.Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
|
||||
}
|
||||
return outAccounts, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
var accounts []accountModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("platform = ? AND status = ?", platform, service.StatusActive).
|
||||
Preload("Proxy").
|
||||
Order("priority ASC").
|
||||
Find(&accounts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outAccounts := make([]service.Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
|
||||
}
|
||||
return outAccounts, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
|
||||
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Update("last_used_at", now).Error
|
||||
}
|
||||
|
||||
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var caseSql = "UPDATE accounts SET last_used_at = CASE id"
|
||||
var args []any
|
||||
var ids []int64
|
||||
|
||||
for id, ts := range updates {
|
||||
caseSql += " WHEN ? THEN CAST(? AS TIMESTAMP)"
|
||||
args = append(args, id, ts)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
caseSql += " END WHERE id IN ?"
|
||||
args = append(args, ids)
|
||||
|
||||
return r.db.WithContext(ctx).Exec(caseSql, args...).Error
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"status": model.StatusError,
|
||||
"status": service.StatusError,
|
||||
"error_message": errorMsg,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||
ag := &model.AccountGroup{
|
||||
ag := &accountGroupModel{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Priority: priority,
|
||||
@@ -176,131 +211,148 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
|
||||
|
||||
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
|
||||
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
|
||||
Delete(&model.AccountGroup{}).Error
|
||||
Delete(&accountGroupModel{}).Error
|
||||
}
|
||||
|
||||
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
|
||||
var groups []model.Group
|
||||
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
|
||||
var groups []groupModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.group_id = groups.id").
|
||||
Where("account_groups.account_id = ?", accountID).
|
||||
Find(&groups).Error
|
||||
return groups, err
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("platform = ? AND status = ?", platform, model.StatusActive).
|
||||
Preload("Proxy").
|
||||
Order("priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
outGroups := make([]service.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *groupModelToService(&groups[i]))
|
||||
}
|
||||
return outGroups, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
// 删除现有绑定
|
||||
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
|
||||
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&accountGroupModel{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 添加新绑定
|
||||
if len(groupIDs) > 0 {
|
||||
accountGroups := make([]model.AccountGroup, 0, len(groupIDs))
|
||||
for i, groupID := range groupIDs {
|
||||
accountGroups = append(accountGroups, model.AccountGroup{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Priority: i + 1, // 使用索引作为优先级
|
||||
})
|
||||
}
|
||||
return r.db.WithContext(ctx).Create(&accountGroups).Error
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
accountGroups := make([]accountGroupModel, 0, len(groupIDs))
|
||||
for i, groupID := range groupIDs {
|
||||
accountGroups = append(accountGroups, accountGroupModel{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Priority: i + 1,
|
||||
})
|
||||
}
|
||||
return r.db.WithContext(ctx).Create(&accountGroups).Error
|
||||
}
|
||||
|
||||
// ListSchedulable 获取所有可调度的账号
|
||||
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
|
||||
var accounts []accountModel
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ? AND schedulable = ?", model.StatusActive, true).
|
||||
Where("status = ? AND schedulable = ?", service.StatusActive, true).
|
||||
Where("(overload_until IS NULL OR overload_until <= ?)", now).
|
||||
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outAccounts := make([]service.Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
|
||||
}
|
||||
return outAccounts, nil
|
||||
}
|
||||
|
||||
// ListSchedulableByGroupID 按组获取可调度的账号
|
||||
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
var accounts []accountModel
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
|
||||
Where("account_groups.group_id = ?", groupID).
|
||||
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
|
||||
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
|
||||
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
|
||||
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("account_groups.priority ASC, accounts.priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outAccounts := make([]service.Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
|
||||
}
|
||||
return outAccounts, nil
|
||||
}
|
||||
|
||||
// ListSchedulableByPlatform 按平台获取可调度的账号
|
||||
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
var accounts []accountModel
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("platform = ?", platform).
|
||||
Where("status = ? AND schedulable = ?", model.StatusActive, true).
|
||||
Where("status = ? AND schedulable = ?", service.StatusActive, true).
|
||||
Where("(overload_until IS NULL OR overload_until <= ?)", now).
|
||||
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outAccounts := make([]service.Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
|
||||
}
|
||||
return outAccounts, nil
|
||||
}
|
||||
|
||||
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
|
||||
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
|
||||
var accounts []model.Account
|
||||
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||
var accounts []accountModel
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
|
||||
Where("account_groups.group_id = ?", groupID).
|
||||
Where("accounts.platform = ?", platform).
|
||||
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
|
||||
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
|
||||
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
|
||||
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
|
||||
Preload("Proxy").
|
||||
Order("account_groups.priority ASC, accounts.priority ASC").
|
||||
Find(&accounts).Error
|
||||
return accounts, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outAccounts := make([]service.Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
|
||||
}
|
||||
return outAccounts, nil
|
||||
}
|
||||
|
||||
// SetRateLimited 标记账号为限流状态(429)
|
||||
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"rate_limited_at": now,
|
||||
"rate_limit_reset_at": resetAt,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// SetOverloaded 标记账号为过载状态(529)
|
||||
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
|
||||
Update("overload_until", until).Error
|
||||
}
|
||||
|
||||
// ClearRateLimit 清除账号的限流状态
|
||||
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"rate_limited_at": nil,
|
||||
"rate_limit_reset_at": nil,
|
||||
@@ -308,7 +360,6 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateSessionWindow 更新账号的5小时时间窗口信息
|
||||
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
updates := map[string]any{
|
||||
"session_window_status": status,
|
||||
@@ -319,45 +370,35 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
|
||||
if end != nil {
|
||||
updates["session_window_end"] = end
|
||||
}
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Updates(updates).Error
|
||||
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// SetSchedulable 设置账号的调度开关
|
||||
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
|
||||
Update("schedulable", schedulable).Error
|
||||
}
|
||||
|
||||
// UpdateExtra updates specific fields in account's Extra JSONB field
|
||||
// It merges the updates into existing Extra data without overwriting other fields
|
||||
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get current account to preserve existing Extra data
|
||||
var account model.Account
|
||||
var account accountModel
|
||||
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize Extra if nil
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(model.JSONB)
|
||||
account.Extra = datatypes.JSONMap{}
|
||||
}
|
||||
|
||||
// Merge updates into existing Extra
|
||||
for k, v := range updates {
|
||||
account.Extra[k] = v
|
||||
}
|
||||
|
||||
// Save updated Extra
|
||||
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
|
||||
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
|
||||
Update("extra", account.Extra).Error
|
||||
}
|
||||
|
||||
// BulkUpdate updates multiple accounts with the provided fields.
|
||||
// It merges credentials/extra JSONB fields instead of overwriting them.
|
||||
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
@@ -381,10 +422,10 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
updateMap["status"] = *updates.Status
|
||||
}
|
||||
if len(updates.Credentials) > 0 {
|
||||
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials)
|
||||
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", datatypes.JSONMap(updates.Credentials))
|
||||
}
|
||||
if len(updates.Extra) > 0 {
|
||||
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra)
|
||||
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", datatypes.JSONMap(updates.Extra))
|
||||
}
|
||||
|
||||
if len(updateMap) == 0 {
|
||||
@@ -392,10 +433,178 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&model.Account{}).
|
||||
Model(&accountModel{}).
|
||||
Where("id IN ?", ids).
|
||||
Clauses(clause.Returning{}).
|
||||
Updates(updateMap)
|
||||
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
type accountModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
Name string `gorm:"size:100;not null"`
|
||||
Platform string `gorm:"size:50;not null"`
|
||||
Type string `gorm:"size:20;not null"`
|
||||
Credentials datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
|
||||
Extra datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
|
||||
ProxyID *int64 `gorm:"index"`
|
||||
Concurrency int `gorm:"default:3;not null"`
|
||||
Priority int `gorm:"default:50;not null"`
|
||||
Status string `gorm:"size:20;default:active;not null"`
|
||||
ErrorMessage string `gorm:"type:text"`
|
||||
LastUsedAt *time.Time `gorm:"index"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
|
||||
Schedulable bool `gorm:"default:true;not null"`
|
||||
|
||||
RateLimitedAt *time.Time `gorm:"index"`
|
||||
RateLimitResetAt *time.Time `gorm:"index"`
|
||||
OverloadUntil *time.Time `gorm:"index"`
|
||||
|
||||
SessionWindowStart *time.Time
|
||||
SessionWindowEnd *time.Time
|
||||
SessionWindowStatus string `gorm:"size:20"`
|
||||
|
||||
Proxy *proxyModel `gorm:"foreignKey:ProxyID"`
|
||||
AccountGroups []accountGroupModel `gorm:"foreignKey:AccountID"`
|
||||
}
|
||||
|
||||
func (accountModel) TableName() string { return "accounts" }
|
||||
|
||||
type accountGroupModel struct {
|
||||
AccountID int64 `gorm:"primaryKey"`
|
||||
GroupID int64 `gorm:"primaryKey"`
|
||||
Priority int `gorm:"default:50;not null"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
|
||||
Account *accountModel `gorm:"foreignKey:AccountID"`
|
||||
Group *groupModel `gorm:"foreignKey:GroupID"`
|
||||
}
|
||||
|
||||
func (accountGroupModel) TableName() string { return "account_groups" }
|
||||
|
||||
func accountGroupModelToService(m *accountGroupModel) *service.AccountGroup {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.AccountGroup{
|
||||
AccountID: m.AccountID,
|
||||
GroupID: m.GroupID,
|
||||
Priority: m.Priority,
|
||||
CreatedAt: m.CreatedAt,
|
||||
Account: accountModelToService(m.Account),
|
||||
Group: groupModelToService(m.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func accountModelToService(m *accountModel) *service.Account {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var credentials map[string]any
|
||||
if m.Credentials != nil {
|
||||
credentials = map[string]any(m.Credentials)
|
||||
}
|
||||
|
||||
var extra map[string]any
|
||||
if m.Extra != nil {
|
||||
extra = map[string]any(m.Extra)
|
||||
}
|
||||
|
||||
account := &service.Account{
|
||||
ID: m.ID,
|
||||
Name: m.Name,
|
||||
Platform: m.Platform,
|
||||
Type: m.Type,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: m.ProxyID,
|
||||
Concurrency: m.Concurrency,
|
||||
Priority: m.Priority,
|
||||
Status: m.Status,
|
||||
ErrorMessage: m.ErrorMessage,
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
Schedulable: m.Schedulable,
|
||||
RateLimitedAt: m.RateLimitedAt,
|
||||
RateLimitResetAt: m.RateLimitResetAt,
|
||||
OverloadUntil: m.OverloadUntil,
|
||||
SessionWindowStart: m.SessionWindowStart,
|
||||
SessionWindowEnd: m.SessionWindowEnd,
|
||||
SessionWindowStatus: m.SessionWindowStatus,
|
||||
Proxy: proxyModelToService(m.Proxy),
|
||||
}
|
||||
|
||||
if len(m.AccountGroups) > 0 {
|
||||
account.AccountGroups = make([]service.AccountGroup, 0, len(m.AccountGroups))
|
||||
account.GroupIDs = make([]int64, 0, len(m.AccountGroups))
|
||||
account.Groups = make([]*service.Group, 0, len(m.AccountGroups))
|
||||
for i := range m.AccountGroups {
|
||||
ag := accountGroupModelToService(&m.AccountGroups[i])
|
||||
if ag == nil {
|
||||
continue
|
||||
}
|
||||
account.AccountGroups = append(account.AccountGroups, *ag)
|
||||
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
|
||||
if ag.Group != nil {
|
||||
account.Groups = append(account.Groups, ag.Group)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return account
|
||||
}
|
||||
|
||||
func accountModelFromService(a *service.Account) *accountModel {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var credentials datatypes.JSONMap
|
||||
if a.Credentials != nil {
|
||||
credentials = datatypes.JSONMap(a.Credentials)
|
||||
}
|
||||
|
||||
var extra datatypes.JSONMap
|
||||
if a.Extra != nil {
|
||||
extra = datatypes.JSONMap(a.Extra)
|
||||
}
|
||||
|
||||
return &accountModel{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Platform: a.Platform,
|
||||
Type: a.Type,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
Priority: a.Priority,
|
||||
Status: a.Status,
|
||||
ErrorMessage: a.ErrorMessage,
|
||||
LastUsedAt: a.LastUsedAt,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
Schedulable: a.Schedulable,
|
||||
RateLimitedAt: a.RateLimitedAt,
|
||||
RateLimitResetAt: a.RateLimitResetAt,
|
||||
OverloadUntil: a.OverloadUntil,
|
||||
SessionWindowStart: a.SessionWindowStart,
|
||||
SessionWindowEnd: a.SessionWindowEnd,
|
||||
SessionWindowStatus: a.SessionWindowStatus,
|
||||
}
|
||||
}
|
||||
|
||||
func applyAccountModelToService(account *service.Account, m *accountModel) {
|
||||
if account == nil || m == nil {
|
||||
return
|
||||
}
|
||||
account.ID = m.ID
|
||||
account.CreatedAt = m.CreatedAt
|
||||
account.UpdatedAt = m.UpdatedAt
|
||||
}
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -34,11 +34,16 @@ func TestAccountRepoSuite(t *testing.T) {
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *AccountRepoSuite) TestCreate() {
|
||||
account := &model.Account{
|
||||
Name: "test-create",
|
||||
Platform: model.PlatformAnthropic,
|
||||
Type: model.AccountTypeOAuth,
|
||||
Status: model.StatusActive,
|
||||
account := &service.Account{
|
||||
Name: "test-create",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{},
|
||||
Extra: map[string]any{},
|
||||
Concurrency: 3,
|
||||
Priority: 50,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, account)
|
||||
@@ -56,7 +61,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "original"})
|
||||
account := accountModelToService(mustCreateAccount(s.T(), s.db, &accountModel{Name: "original"}))
|
||||
|
||||
account.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
@@ -68,7 +73,7 @@ func (s *AccountRepoSuite) TestUpdate() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "to-delete"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "to-delete"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
@@ -78,23 +83,23 @@ func (s *AccountRepoSuite) TestDelete() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
|
||||
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||
|
||||
err := s.repo.Delete(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "Delete should cascade remove bindings")
|
||||
|
||||
var count int64
|
||||
s.db.Model(&model.AccountGroup{}).Where("account_id = ?", account.ID).Count(&count)
|
||||
s.db.Model(&accountGroupModel{}).Where("account_id = ?", account.ID).Count(&count)
|
||||
s.Require().Zero(count, "expected bindings to be removed")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *AccountRepoSuite) TestList() {
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc2"})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc2"})
|
||||
|
||||
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
@@ -111,53 +116,53 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
status string
|
||||
search string
|
||||
wantCount int
|
||||
validate func(accounts []model.Account)
|
||||
validate func(accounts []service.Account)
|
||||
}{
|
||||
{
|
||||
name: "filter_by_platform",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic})
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI})
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic})
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI})
|
||||
},
|
||||
platform: model.PlatformOpenAI,
|
||||
platform: service.PlatformOpenAI,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []model.Account) {
|
||||
s.Require().Equal(model.PlatformOpenAI, accounts[0].Platform)
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal(service.PlatformOpenAI, accounts[0].Platform)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_type",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "t1", Type: model.AccountTypeOAuth})
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "t2", Type: model.AccountTypeApiKey})
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "t1", Type: service.AccountTypeOAuth})
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "t2", Type: service.AccountTypeApiKey})
|
||||
},
|
||||
accType: model.AccountTypeApiKey,
|
||||
accType: service.AccountTypeApiKey,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []model.Account) {
|
||||
s.Require().Equal(model.AccountTypeApiKey, accounts[0].Type)
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_status",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "s1", Status: model.StatusActive})
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "s2", Status: model.StatusDisabled})
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "s1", Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "s2", Status: service.StatusDisabled})
|
||||
},
|
||||
status: model.StatusDisabled,
|
||||
status: service.StatusDisabled,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []model.Account) {
|
||||
s.Require().Equal(model.StatusDisabled, accounts[0].Status)
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal(service.StatusDisabled, accounts[0].Status)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_search",
|
||||
setup: func(db *gorm.DB) {
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "alpha-account"})
|
||||
mustCreateAccount(s.T(), db, &model.Account{Name: "beta-account"})
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "alpha-account"})
|
||||
mustCreateAccount(s.T(), db, &accountModel{Name: "beta-account"})
|
||||
},
|
||||
search: "alpha",
|
||||
wantCount: 1,
|
||||
validate: func(accounts []model.Account) {
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Contains(accounts[0].Name, "alpha")
|
||||
},
|
||||
},
|
||||
@@ -185,9 +190,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
// --- ListByGroup / ListActive / ListByPlatform ---
|
||||
|
||||
func (s *AccountRepoSuite) TestListByGroup() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
|
||||
acc1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Status: model.StatusActive})
|
||||
acc2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Status: model.StatusActive})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
|
||||
acc1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Status: service.StatusActive})
|
||||
acc2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Status: service.StatusActive})
|
||||
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
|
||||
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
|
||||
|
||||
@@ -199,8 +204,8 @@ func (s *AccountRepoSuite) TestListByGroup() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListActive() {
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "active1", Status: model.StatusActive})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "inactive1", Status: model.StatusDisabled})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "active1", Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "inactive1", Status: service.StatusDisabled})
|
||||
|
||||
accounts, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
@@ -209,22 +214,22 @@ func (s *AccountRepoSuite) TestListActive() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListByPlatform() {
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
|
||||
|
||||
accounts, err := s.repo.ListByPlatform(s.ctx, model.PlatformAnthropic)
|
||||
accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
|
||||
s.Require().NoError(err, "ListByPlatform")
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
|
||||
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
|
||||
}
|
||||
|
||||
// --- Preload and VirtualFields ---
|
||||
|
||||
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
|
||||
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
|
||||
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "acc1",
|
||||
ProxyID: &proxy.ID,
|
||||
})
|
||||
@@ -252,9 +257,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
||||
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
|
||||
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc"})
|
||||
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc"})
|
||||
|
||||
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
|
||||
groups, err := s.repo.GetGroups(s.ctx, account.ID)
|
||||
@@ -274,8 +279,8 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-empty"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-empty"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
|
||||
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||
|
||||
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
|
||||
@@ -289,13 +294,13 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulable() {
|
||||
now := time.Now()
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
|
||||
|
||||
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
|
||||
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
|
||||
|
||||
future := now.Add(10 * time.Minute)
|
||||
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
|
||||
|
||||
sched, err := s.repo.ListSchedulable(s.ctx)
|
||||
@@ -307,16 +312,16 @@ func (s *AccountRepoSuite) TestListSchedulable() {
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
|
||||
now := time.Now()
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
|
||||
|
||||
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
|
||||
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
|
||||
|
||||
future := now.Add(10 * time.Minute)
|
||||
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
|
||||
|
||||
rateLimited := mustCreateAccount(s.T(), s.db, &model.Account{Name: "rl", Schedulable: true})
|
||||
rateLimited := mustCreateAccount(s.T(), s.db, &accountModel{Name: "rl", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
|
||||
|
||||
@@ -334,30 +339,30 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
|
||||
|
||||
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, model.PlatformAnthropic)
|
||||
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
|
||||
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sp"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sp"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
|
||||
|
||||
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, model.PlatformAnthropic)
|
||||
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(a1.ID, accounts[0].ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-sched", Schedulable: true})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-sched", Schedulable: true})
|
||||
|
||||
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
|
||||
|
||||
@@ -369,7 +374,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
|
||||
|
||||
func (s *AccountRepoSuite) TestSetOverloaded() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-over"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-over"})
|
||||
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||
@@ -381,7 +386,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestSetRateLimited() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-rl"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-rl"})
|
||||
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
|
||||
@@ -394,7 +399,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearRateLimit() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-clear"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-clear"})
|
||||
until := time.Now().Add(1 * time.Hour)
|
||||
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
|
||||
@@ -411,7 +416,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
|
||||
// --- UpdateLastUsed ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateLastUsed() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-used"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-used"})
|
||||
s.Require().Nil(account.LastUsedAt)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
|
||||
@@ -424,20 +429,20 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() {
|
||||
// --- SetError ---
|
||||
|
||||
func (s *AccountRepoSuite) TestSetError() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-err", Status: model.StatusActive})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-err", Status: service.StatusActive})
|
||||
|
||||
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(model.StatusError, got.Status)
|
||||
s.Require().Equal(service.StatusError, got.Status)
|
||||
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||
}
|
||||
|
||||
// --- UpdateSessionWindow ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-win"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-win"})
|
||||
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
|
||||
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
|
||||
|
||||
@@ -453,9 +458,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
// --- UpdateExtra ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "acc-extra",
|
||||
Extra: model.JSONB{"a": "1"},
|
||||
Extra: datatypes.JSONMap{"a": "1"},
|
||||
})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
|
||||
|
||||
@@ -466,12 +471,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-extra-empty"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-extra-empty"})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-nil-extra", Extra: nil})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-nil-extra", Extra: nil})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
@@ -483,9 +488,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||
crsID := "crs-12345"
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "acc-crs",
|
||||
Extra: model.JSONB{"crs_account_id": crsID},
|
||||
Extra: datatypes.JSONMap{"crs_account_id": crsID},
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
|
||||
@@ -509,8 +514,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
|
||||
// --- BulkUpdate ---
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk1", Priority: 1})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk1", Priority: 1})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk2", Priority: 1})
|
||||
|
||||
newPriority := 99
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
|
||||
@@ -526,13 +531,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "bulk-cred",
|
||||
Credentials: model.JSONB{"existing": "value"},
|
||||
Credentials: datatypes.JSONMap{"existing": "value"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
|
||||
Credentials: model.JSONB{"new_key": "new_value"},
|
||||
Credentials: datatypes.JSONMap{"new_key": "new_value"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
@@ -542,13 +547,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
|
||||
Name: "bulk-extra",
|
||||
Extra: model.JSONB{"existing": "val"},
|
||||
Extra: datatypes.JSONMap{"existing": "val"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
|
||||
Extra: model.JSONB{"new_key": "new_val"},
|
||||
Extra: datatypes.JSONMap{"new_key": "new_val"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
@@ -564,14 +569,14 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk-empty"})
|
||||
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected)
|
||||
}
|
||||
|
||||
func idsOfAccounts(accounts []model.Account) []int64 {
|
||||
func idsOfAccounts(accounts []service.Account) []int64 {
|
||||
out := make([]int64, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
out = append(out, accounts[i].ID)
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -14,6 +15,11 @@ const (
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
)
|
||||
|
||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||
func apiKeyRateLimitKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
}
|
||||
|
||||
type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -23,12 +29,16 @@ func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
count, err := c.rdb.Get(ctx, key).Int()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||
@@ -37,7 +47,7 @@ func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID in
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
|
||||
@@ -23,13 +23,14 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
name: "missing_key_returns_zero_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
|
||||
_, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing key")
|
||||
require.NoError(s.T(), err, "expected nil error for missing key")
|
||||
require.Equal(s.T(), 0, count, "expected zero count for missing key")
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -58,8 +59,9 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
|
||||
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
|
||||
|
||||
_, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after delete")
|
||||
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
require.NoError(s.T(), err, "expected nil error after delete")
|
||||
require.Equal(s.T(), 0, count, "expected zero count after delete")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
46
backend/internal/repository/api_key_cache_test.go
Normal file
46
backend/internal/repository/api_key_cache_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApiKeyRateLimitKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_user_id",
|
||||
userID: 123,
|
||||
expected: "apikey:ratelimit:123",
|
||||
},
|
||||
{
|
||||
name: "zero_user_id",
|
||||
userID: 0,
|
||||
expected: "apikey:ratelimit:0",
|
||||
},
|
||||
{
|
||||
name: "negative_user_id",
|
||||
userID: -1,
|
||||
expected: "apikey:ratelimit:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
userID: math.MaxInt64,
|
||||
expected: "apikey:ratelimit:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := apiKeyRateLimitKey(tc.userID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,10 +2,10 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -19,42 +19,51 @@ func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
|
||||
return &apiKeyRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
|
||||
err := r.db.WithContext(ctx).Create(key).Error
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
m := apiKeyModelFromService(key)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
if err == nil {
|
||||
applyApiKeyModelToService(key, m)
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
|
||||
var key model.ApiKey
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
var m apiKeyModel
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
|
||||
}
|
||||
return &key, nil
|
||||
return apiKeyModelToService(&m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
|
||||
var apiKey model.ApiKey
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
var m apiKeyModel
|
||||
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
|
||||
}
|
||||
return &apiKey, nil
|
||||
return apiKeyModelToService(&m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
|
||||
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
m := apiKeyModelFromService(key)
|
||||
err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error
|
||||
if err == nil {
|
||||
applyApiKeyModelToService(key, m)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
|
||||
return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []model.ApiKey
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []apiKeyModel
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID)
|
||||
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID)
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
@@ -64,36 +73,47 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
|
||||
}
|
||||
|
||||
return keys, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
return outKeys, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(apiKeyIDs))
|
||||
err := r.db.WithContext(ctx).
|
||||
Model(&apiKeyModel{}).
|
||||
Where("user_id = ? AND id IN ?", userID, apiKeyIDs).
|
||||
Pluck("id", &ids).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
|
||||
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []model.ApiKey
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
var keys []apiKeyModel
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID)
|
||||
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID)
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
@@ -103,24 +123,19 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
|
||||
}
|
||||
|
||||
return keys, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
return outKeys, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
|
||||
var keys []model.ApiKey
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
var keys []apiKeyModel
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.ApiKey{})
|
||||
db := r.db.WithContext(ctx).Model(&apiKeyModel{})
|
||||
|
||||
if userID > 0 {
|
||||
db = db.Where("user_id = ?", userID)
|
||||
@@ -135,12 +150,16 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
|
||||
}
|
||||
return outKeys, nil
|
||||
}
|
||||
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
|
||||
result := r.db.WithContext(ctx).Model(&apiKeyModel{}).
|
||||
Where("group_id = ?", groupID).
|
||||
Update("group_id", nil)
|
||||
return result.RowsAffected, result.Error
|
||||
@@ -149,6 +168,66 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
|
||||
// CountByGroupID 获取分组的 API Key 数量
|
||||
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
|
||||
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
type apiKeyModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
UserID int64 `gorm:"index;not null"`
|
||||
Key string `gorm:"uniqueIndex;size:128;not null"`
|
||||
Name string `gorm:"size:100;not null"`
|
||||
GroupID *int64 `gorm:"index"`
|
||||
Status string `gorm:"size:20;default:active;not null"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
|
||||
User *userModel `gorm:"foreignKey:UserID"`
|
||||
Group *groupModel `gorm:"foreignKey:GroupID"`
|
||||
}
|
||||
|
||||
func (apiKeyModel) TableName() string { return "api_keys" }
|
||||
|
||||
func apiKeyModelToService(m *apiKeyModel) *service.ApiKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.ApiKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
GroupID: m.GroupID,
|
||||
Status: m.Status,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
User: userModelToService(m.User),
|
||||
Group: groupModelToService(m.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &apiKeyModel{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) {
|
||||
if key == nil || m == nil {
|
||||
return
|
||||
}
|
||||
key.ID = m.ID
|
||||
key.CreatedAt = m.CreatedAt
|
||||
key.UpdatedAt = m.UpdatedAt
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -32,13 +32,13 @@ func TestApiKeyRepoSuite(t *testing.T) {
|
||||
// --- Create / GetByID / GetByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCreate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
|
||||
|
||||
key := &model.ApiKey{
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-create-test",
|
||||
Name: "Test Key",
|
||||
Status: model.StatusActive,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, key)
|
||||
@@ -56,15 +56,15 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbykey@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-key"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbykey@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-key"})
|
||||
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
UserID: user.ID,
|
||||
Key: "sk-getbykey",
|
||||
Name: "My Key",
|
||||
GroupID: &group.ID,
|
||||
Status: model.StatusActive,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
@@ -84,16 +84,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
// --- Update ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
|
||||
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update",
|
||||
Name: "Original",
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
Status: service.StatusActive,
|
||||
}))
|
||||
|
||||
key.Name = "Renamed"
|
||||
key.Status = model.StatusDisabled
|
||||
key.Status = service.StatusDisabled
|
||||
err := s.repo.Update(s.ctx, key)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
@@ -102,18 +102,18 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
s.Require().Equal("sk-update", got.Key, "Update should not change key")
|
||||
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
|
||||
s.Require().Equal("Renamed", got.Name)
|
||||
s.Require().Equal(model.StatusDisabled, got.Status)
|
||||
s.Require().Equal(service.StatusDisabled, got.Status)
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear"})
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear"})
|
||||
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
UserID: user.ID,
|
||||
Key: "sk-clear-group",
|
||||
Name: "Group Key",
|
||||
GroupID: &group.ID,
|
||||
})
|
||||
}))
|
||||
|
||||
key.GroupID = nil
|
||||
err := s.repo.Update(s.ctx, key)
|
||||
@@ -127,8 +127,8 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
// --- Delete ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
|
||||
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
UserID: user.ID,
|
||||
Key: "sk-delete",
|
||||
Name: "Delete Me",
|
||||
@@ -144,9 +144,9 @@ func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
// --- ListByUserID / CountByUserID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
@@ -155,9 +155,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "paging@test.com"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "paging@test.com"})
|
||||
for i := 0; i < 5; i++ {
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
UserID: user.ID,
|
||||
Key: "sk-page-" + string(rune('a'+i)),
|
||||
Name: "Key",
|
||||
@@ -172,9 +172,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "count@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "count@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
|
||||
|
||||
count, err := s.repo.CountByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "CountByUserID")
|
||||
@@ -184,12 +184,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
// --- ListByGroupID / CountByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbygroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbygroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
|
||||
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
|
||||
|
||||
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByGroupID")
|
||||
@@ -200,10 +200,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "countgroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "countgroup@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
|
||||
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
|
||||
|
||||
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
@@ -213,8 +213,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
// --- ExistsByKey ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-exists", Name: "K"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-exists", Name: "K"})
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
@@ -228,9 +228,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
// --- SearchApiKeys ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "search@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "search@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
@@ -239,9 +239,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnokw@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnokw@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
|
||||
s.Require().NoError(err)
|
||||
@@ -249,8 +249,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
|
||||
}
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnouid@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnouid@test.com"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
|
||||
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
|
||||
s.Require().NoError(err)
|
||||
@@ -260,12 +260,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
|
||||
// --- ClearGroupIDByGroupID ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargrp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear-bulk"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargrp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear-bulk"})
|
||||
|
||||
k1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
|
||||
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
|
||||
k1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
|
||||
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
@@ -283,16 +283,16 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
|
||||
|
||||
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "k@example.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-k"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "k@example.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-k"})
|
||||
|
||||
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
UserID: user.ID,
|
||||
Key: "sk-test-1",
|
||||
Name: "My Key",
|
||||
GroupID: &group.ID,
|
||||
Status: model.StatusActive,
|
||||
})
|
||||
Status: service.StatusActive,
|
||||
}))
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
s.Require().NoError(err, "GetByKey")
|
||||
@@ -303,7 +303,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
|
||||
key.Name = "Renamed"
|
||||
key.Status = model.StatusDisabled
|
||||
key.Status = service.StatusDisabled
|
||||
key.GroupID = nil
|
||||
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
|
||||
|
||||
@@ -312,7 +312,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
|
||||
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
|
||||
s.Require().Equal("Renamed", got2.Name)
|
||||
s.Require().Equal(model.StatusDisabled, got2.Status)
|
||||
s.Require().Equal(service.StatusDisabled, got2.Status)
|
||||
s.Require().Nil(got2.GroupID)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
@@ -330,7 +330,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().Equal(key.ID, found[0].ID)
|
||||
|
||||
// ClearGroupIDByGroupID
|
||||
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
|
||||
UserID: user.ID,
|
||||
Key: "sk-test-2",
|
||||
Name: "Group Key",
|
||||
|
||||
49
backend/internal/repository/auto_migrate.go
Normal file
49
backend/internal/repository/auto_migrate.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// MaxExpiresAt is the maximum allowed expiration date for subscriptions (year 2099)
|
||||
// This prevents time.Time JSON serialization errors (RFC 3339 requires year <= 9999)
|
||||
var maxExpiresAt = time.Date(2099, 12, 31, 23, 59, 59, 0, time.UTC)
|
||||
|
||||
// AutoMigrate runs schema migrations for all repository persistence models.
|
||||
// Persistence models are defined within individual `*_repo.go` files.
|
||||
func AutoMigrate(db *gorm.DB) error {
|
||||
err := db.AutoMigrate(
|
||||
&userModel{},
|
||||
&apiKeyModel{},
|
||||
&groupModel{},
|
||||
&accountModel{},
|
||||
&accountGroupModel{},
|
||||
&proxyModel{},
|
||||
&redeemCodeModel{},
|
||||
&usageLogModel{},
|
||||
&settingModel{},
|
||||
&userSubscriptionModel{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 修复无效的过期时间(年份超过 2099 会导致 JSON 序列化失败)
|
||||
return fixInvalidExpiresAt(db)
|
||||
}
|
||||
|
||||
// fixInvalidExpiresAt 修复 user_subscriptions 表中无效的过期时间
|
||||
func fixInvalidExpiresAt(db *gorm.DB) error {
|
||||
result := db.Model(&userSubscriptionModel{}).
|
||||
Where("expires_at > ?", maxExpiresAt).
|
||||
Update("expires_at", maxExpiresAt)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected > 0 {
|
||||
log.Printf("[AutoMigrate] Fixed %d subscriptions with invalid expires_at (year > 2099)", result.RowsAffected)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -18,6 +18,16 @@ const (
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// billingBalanceKey generates the Redis key for user balance cache.
|
||||
func billingBalanceKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
}
|
||||
|
||||
// billingSubKey generates the Redis key for subscription cache.
|
||||
func billingSubKey(userID, groupID int64) string {
|
||||
return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
}
|
||||
|
||||
const (
|
||||
subFieldStatus = "status"
|
||||
subFieldExpiresAt = "expires_at"
|
||||
@@ -62,7 +72,7 @@ func NewBillingCache(rdb *redis.Client) service.BillingCache {
|
||||
}
|
||||
|
||||
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
key := billingBalanceKey(userID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -71,12 +81,12 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
|
||||
}
|
||||
|
||||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
key := billingBalanceKey(userID)
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
@@ -85,12 +95,12 @@ func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amou
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
key := billingSubKey(userID, groupID)
|
||||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -140,7 +150,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
|
||||
return nil
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
key := billingSubKey(userID, groupID)
|
||||
|
||||
fields := map[string]any{
|
||||
subFieldStatus: data.Status,
|
||||
@@ -159,7 +169,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
key := billingSubKey(userID, groupID)
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
@@ -168,6 +178,6 @@ func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, grou
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
key := billingSubKey(userID, groupID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
87
backend/internal/repository/billing_cache_test.go
Normal file
87
backend/internal/repository/billing_cache_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBillingBalanceKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_user_id",
|
||||
userID: 123,
|
||||
expected: "billing:balance:123",
|
||||
},
|
||||
{
|
||||
name: "zero_user_id",
|
||||
userID: 0,
|
||||
expected: "billing:balance:0",
|
||||
},
|
||||
{
|
||||
name: "negative_user_id",
|
||||
userID: -1,
|
||||
expected: "billing:balance:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
userID: math.MaxInt64,
|
||||
expected: "billing:balance:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := billingBalanceKey(tc.userID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingSubKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
groupID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_ids",
|
||||
userID: 123,
|
||||
groupID: 456,
|
||||
expected: "billing:sub:123:456",
|
||||
},
|
||||
{
|
||||
name: "zero_ids",
|
||||
userID: 0,
|
||||
groupID: 0,
|
||||
expected: "billing:sub:0:0",
|
||||
},
|
||||
{
|
||||
name: "negative_ids",
|
||||
userID: -1,
|
||||
groupID: -2,
|
||||
expected: "billing:sub:-1:-2",
|
||||
},
|
||||
{
|
||||
name: "max_int64_ids",
|
||||
userID: math.MaxInt64,
|
||||
groupID: math.MaxInt64,
|
||||
expected: "billing:sub:9223372036854775807:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := billingSubKey(tc.userID, tc.groupID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -145,7 +145,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
return fullCode, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
// Parse code which may contain state in format "authCode#state"
|
||||
@@ -168,6 +168,11 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
reqBody["state"] = codeState
|
||||
}
|
||||
|
||||
// Setup token requires longer expiration (1 year)
|
||||
if isSetupToken {
|
||||
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
|
||||
}
|
||||
|
||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
|
||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
@@ -199,16 +204,20 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", oauth.ClientID)
|
||||
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
|
||||
// Anthropic OAuth API 期望 JSON 格式的请求体
|
||||
reqBody := map[string]any{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refreshToken,
|
||||
"client_id": oauth.ClientID,
|
||||
}
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -34,7 +33,6 @@ type requestCapture struct {
|
||||
method string
|
||||
cookies []*http.Cookie
|
||||
body []byte
|
||||
formValues url.Values
|
||||
bodyJSON map[string]any
|
||||
contentType string
|
||||
}
|
||||
@@ -193,12 +191,13 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
code string
|
||||
wantErr bool
|
||||
wantResp *oauth.TokenResponse
|
||||
validate func(captured requestCapture)
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
code string
|
||||
isSetupToken bool
|
||||
wantErr bool
|
||||
wantResp *oauth.TokenResponse
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "sends_state_when_embedded",
|
||||
@@ -212,7 +211,8 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
|
||||
Scope: "s",
|
||||
})
|
||||
},
|
||||
code: "AUTH#STATE2",
|
||||
code: "AUTH#STATE2",
|
||||
isSetupToken: false,
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rt",
|
||||
@@ -225,6 +225,29 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
|
||||
require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
|
||||
// Regular OAuth should not include expires_in
|
||||
require.Nil(s.T(), captured.bodyJSON["expires_in"], "regular OAuth should not include expires_in")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "setup_token_includes_expires_in",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 31536000,
|
||||
})
|
||||
},
|
||||
code: "AUTH",
|
||||
isSetupToken: true,
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
},
|
||||
validate: func(captured requestCapture) {
|
||||
// Setup token should include expires_in with 1 year value
|
||||
require.Equal(s.T(), float64(31536000), captured.bodyJSON["expires_in"],
|
||||
"setup token should include expires_in: 31536000")
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -233,8 +256,9 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("bad request"))
|
||||
},
|
||||
code: "AUTH",
|
||||
wantErr: true,
|
||||
code: "AUTH",
|
||||
isSetupToken: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -256,7 +280,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
|
||||
s.client = client
|
||||
s.client.tokenURL = s.srv.URL
|
||||
|
||||
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "")
|
||||
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
@@ -282,24 +306,53 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "sends_form",
|
||||
name: "sends_json_format",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{AccessToken: "at2", TokenType: "bearer", ExpiresIn: 3600})
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "new_access_token",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 28800,
|
||||
RefreshToken: "new_refresh_token",
|
||||
Scope: "user:profile user:inference",
|
||||
})
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "new_access_token",
|
||||
RefreshToken: "new_refresh_token",
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{AccessToken: "at2"},
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.Equal(s.T(), "refresh_token", captured.formValues.Get("grant_type"))
|
||||
require.Equal(s.T(), "rt", captured.formValues.Get("refresh_token"))
|
||||
require.Equal(s.T(), oauth.ClientID, captured.formValues.Get("client_id"))
|
||||
// 验证使用 JSON 格式(不是 form 格式)
|
||||
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"),
|
||||
"expected JSON content-type, got: %s", captured.contentType)
|
||||
// 验证 JSON body 内容
|
||||
require.Equal(s.T(), "refresh_token", captured.bodyJSON["grant_type"])
|
||||
require.Equal(s.T(), "rt", captured.bodyJSON["refresh_token"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "returns_new_refresh_token",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 28800,
|
||||
RefreshToken: "rotated_rt", // Anthropic rotates refresh tokens
|
||||
})
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rotated_rt",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte("unauthorized"))
|
||||
_, _ = w.Write([]byte(`{"error":"invalid_grant"}`))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
@@ -311,8 +364,9 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
captured.formValues, _ = url.ParseQuery(string(captured.body))
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
@@ -331,6 +385,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,11 @@ import (
|
||||
|
||||
const verifyCodeKeyPrefix = "verify_code:"
|
||||
|
||||
// verifyCodeKey generates the Redis key for email verification code.
|
||||
func verifyCodeKey(email string) string {
|
||||
return verifyCodeKeyPrefix + email
|
||||
}
|
||||
|
||||
type emailCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -20,7 +25,7 @@ func NewEmailCache(rdb *redis.Client) service.EmailCache {
|
||||
}
|
||||
|
||||
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
key := verifyCodeKey(email)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -33,7 +38,7 @@ func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*se
|
||||
}
|
||||
|
||||
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
key := verifyCodeKey(email)
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -42,6 +47,6 @@ func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data
|
||||
}
|
||||
|
||||
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
|
||||
key := verifyCodeKeyPrefix + email
|
||||
key := verifyCodeKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
45
backend/internal/repository/email_cache_test.go
Normal file
45
backend/internal/repository/email_cache_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVerifyCodeKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_email",
|
||||
email: "user@example.com",
|
||||
expected: "verify_code:user@example.com",
|
||||
},
|
||||
{
|
||||
name: "empty_email",
|
||||
email: "",
|
||||
expected: "verify_code:",
|
||||
},
|
||||
{
|
||||
name: "email_with_plus",
|
||||
email: "user+tag@example.com",
|
||||
expected: "verify_code:user+tag@example.com",
|
||||
},
|
||||
{
|
||||
name: "email_with_special_chars",
|
||||
email: "user.name+tag@sub.domain.com",
|
||||
expected: "verify_code:user.name+tag@sub.domain.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := verifyCodeKey(tc.email)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,21 +6,25 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
|
||||
func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel {
|
||||
t.Helper()
|
||||
if u.PasswordHash == "" {
|
||||
u.PasswordHash = "test-password-hash"
|
||||
}
|
||||
if u.Role == "" {
|
||||
u.Role = model.RoleUser
|
||||
u.Role = service.RoleUser
|
||||
}
|
||||
if u.Status == "" {
|
||||
u.Status = model.StatusActive
|
||||
u.Status = service.StatusActive
|
||||
}
|
||||
if u.Concurrency == 0 {
|
||||
u.Concurrency = 5
|
||||
}
|
||||
if u.CreatedAt.IsZero() {
|
||||
u.CreatedAt = time.Now()
|
||||
@@ -32,16 +36,16 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
|
||||
return u
|
||||
}
|
||||
|
||||
func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
|
||||
func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel {
|
||||
t.Helper()
|
||||
if g.Platform == "" {
|
||||
g.Platform = model.PlatformAnthropic
|
||||
g.Platform = service.PlatformAnthropic
|
||||
}
|
||||
if g.Status == "" {
|
||||
g.Status = model.StatusActive
|
||||
g.Status = service.StatusActive
|
||||
}
|
||||
if g.SubscriptionType == "" {
|
||||
g.SubscriptionType = model.SubscriptionTypeStandard
|
||||
g.SubscriptionType = service.SubscriptionTypeStandard
|
||||
}
|
||||
if g.CreatedAt.IsZero() {
|
||||
g.CreatedAt = time.Now()
|
||||
@@ -53,7 +57,7 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
|
||||
return g
|
||||
}
|
||||
|
||||
func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
|
||||
func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel {
|
||||
t.Helper()
|
||||
if p.Protocol == "" {
|
||||
p.Protocol = "http"
|
||||
@@ -65,7 +69,7 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
|
||||
p.Port = 8080
|
||||
}
|
||||
if p.Status == "" {
|
||||
p.Status = model.StatusActive
|
||||
p.Status = service.StatusActive
|
||||
}
|
||||
if p.CreatedAt.IsZero() {
|
||||
p.CreatedAt = time.Now()
|
||||
@@ -77,25 +81,25 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
|
||||
return p
|
||||
}
|
||||
|
||||
func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Account {
|
||||
func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel {
|
||||
t.Helper()
|
||||
if a.Platform == "" {
|
||||
a.Platform = model.PlatformAnthropic
|
||||
a.Platform = service.PlatformAnthropic
|
||||
}
|
||||
if a.Type == "" {
|
||||
a.Type = model.AccountTypeOAuth
|
||||
a.Type = service.AccountTypeOAuth
|
||||
}
|
||||
if a.Status == "" {
|
||||
a.Status = model.StatusActive
|
||||
a.Status = service.StatusActive
|
||||
}
|
||||
if !a.Schedulable {
|
||||
a.Schedulable = true
|
||||
}
|
||||
if a.Credentials == nil {
|
||||
a.Credentials = model.JSONB{}
|
||||
a.Credentials = datatypes.JSONMap{}
|
||||
}
|
||||
if a.Extra == nil {
|
||||
a.Extra = model.JSONB{}
|
||||
a.Extra = datatypes.JSONMap{}
|
||||
}
|
||||
if a.CreatedAt.IsZero() {
|
||||
a.CreatedAt = time.Now()
|
||||
@@ -107,10 +111,10 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Accou
|
||||
return a
|
||||
}
|
||||
|
||||
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey {
|
||||
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *apiKeyModel) *apiKeyModel {
|
||||
t.Helper()
|
||||
if k.Status == "" {
|
||||
k.Status = model.StatusActive
|
||||
k.Status = service.StatusActive
|
||||
}
|
||||
if k.CreatedAt.IsZero() {
|
||||
k.CreatedAt = time.Now()
|
||||
@@ -122,13 +126,13 @@ func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey
|
||||
return k
|
||||
}
|
||||
|
||||
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model.RedeemCode {
|
||||
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *redeemCodeModel) *redeemCodeModel {
|
||||
t.Helper()
|
||||
if c.Status == "" {
|
||||
c.Status = model.StatusUnused
|
||||
c.Status = service.StatusUnused
|
||||
}
|
||||
if c.Type == "" {
|
||||
c.Type = model.RedeemTypeBalance
|
||||
c.Type = service.RedeemTypeBalance
|
||||
}
|
||||
if c.CreatedAt.IsZero() {
|
||||
c.CreatedAt = time.Now()
|
||||
@@ -137,10 +141,10 @@ func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model
|
||||
return c
|
||||
}
|
||||
|
||||
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription) *model.UserSubscription {
|
||||
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel) *userSubscriptionModel {
|
||||
t.Helper()
|
||||
if s.Status == "" {
|
||||
s.Status = model.SubscriptionStatusActive
|
||||
s.Status = service.SubscriptionStatusActive
|
||||
}
|
||||
now := time.Now()
|
||||
if s.StartsAt.IsZero() {
|
||||
@@ -164,9 +168,10 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription
|
||||
|
||||
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
|
||||
t.Helper()
|
||||
require.NoError(t, db.Create(&model.AccountGroup{
|
||||
require.NoError(t, db.Create(&accountGroupModel{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Priority: priority,
|
||||
CreatedAt: time.Now(),
|
||||
}).Error, "create account_group")
|
||||
}
|
||||
|
||||
117
backend/internal/repository/gemini_oauth_client.go
Normal file
117
backend/internal/repository/gemini_oauth_client.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type geminiOAuthClient struct {
|
||||
tokenURL string
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient {
|
||||
return &geminiOAuthClient{
|
||||
tokenURL: geminicli.TokenURL,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) {
|
||||
client := createGeminiReqClient(proxyURL)
|
||||
|
||||
// Use different OAuth clients based on oauthType:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public)
|
||||
// - ai_studio: requires a user-provided OAuth client
|
||||
oauthCfgInput := geminicli.OAuthConfig{
|
||||
ClientID: c.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: c.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" {
|
||||
oauthCfgInput.ClientID = ""
|
||||
oauthCfgInput.ClientSecret = ""
|
||||
}
|
||||
|
||||
oauthCfg, err := geminicli.EffectiveOAuthConfig(oauthCfgInput, oauthType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("client_id", oauthCfg.ClientID)
|
||||
formData.Set("client_secret", oauthCfg.ClientSecret)
|
||||
formData.Set("code", code)
|
||||
formData.Set("code_verifier", codeVerifier)
|
||||
formData.Set("redirect_uri", redirectURI)
|
||||
|
||||
var tokenResp geminicli.TokenResponse
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(c.tokenURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String()))
|
||||
}
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) {
|
||||
client := createGeminiReqClient(proxyURL)
|
||||
|
||||
oauthCfgInput := geminicli.OAuthConfig{
|
||||
ClientID: c.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: c.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" {
|
||||
oauthCfgInput.ClientID = ""
|
||||
oauthCfgInput.ClientSecret = ""
|
||||
}
|
||||
|
||||
oauthCfg, err := geminicli.EffectiveOAuthConfig(oauthCfgInput, oauthType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", oauthCfg.ClientID)
|
||||
formData.Set("client_secret", oauthCfg.ClientSecret)
|
||||
|
||||
var tokenResp geminicli.TokenResponse
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(c.tokenURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String()))
|
||||
}
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func createGeminiReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().SetTimeout(60 * time.Second)
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
return client
|
||||
}
|
||||
44
backend/internal/repository/gemini_token_cache.go
Normal file
44
backend/internal/repository/gemini_token_cache.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
geminiTokenKeyPrefix = "gemini:token:"
|
||||
geminiRefreshLockKeyPrefix = "gemini:refresh_lock:"
|
||||
)
|
||||
|
||||
type geminiTokenCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
|
||||
return &geminiTokenCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Get(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Set(ctx, key, token, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
|
||||
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
105
backend/internal/repository/geminicli_codeassist_client.go
Normal file
105
backend/internal/repository/geminicli_codeassist_client.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type geminiCliCodeAssistClient struct {
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func NewGeminiCliCodeAssistClient() service.GeminiCliCodeAssistClient {
|
||||
return &geminiCliCodeAssistClient{baseURL: geminicli.GeminiCliBaseURL}
|
||||
}
|
||||
|
||||
func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessToken, proxyURL string, reqBody *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) {
|
||||
if reqBody == nil {
|
||||
reqBody = defaultLoadCodeAssistRequest()
|
||||
}
|
||||
|
||||
var out geminicli.LoadCodeAssistResponse
|
||||
resp, err := createGeminiCliReqClient(proxyURL).R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+accessToken).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", geminicli.GeminiCLIUserAgent).
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&out).
|
||||
Post(c.baseURL + "/v1internal:loadCodeAssist")
|
||||
if err != nil {
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist request error: %v\n", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
body := geminicli.SanitizeBodyForLogs(resp.String())
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body)
|
||||
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body)
|
||||
}
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out)
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken, proxyURL string, reqBody *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) {
|
||||
if reqBody == nil {
|
||||
reqBody = defaultOnboardUserRequest()
|
||||
}
|
||||
|
||||
fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody)
|
||||
|
||||
var out geminicli.OnboardUserResponse
|
||||
resp, err := createGeminiCliReqClient(proxyURL).R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+accessToken).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", geminicli.GeminiCLIUserAgent).
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&out).
|
||||
Post(c.baseURL + "/v1internal:onboardUser")
|
||||
if err != nil {
|
||||
fmt.Printf("[CodeAssist] OnboardUser request error: %v\n", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
body := geminicli.SanitizeBodyForLogs(resp.String())
|
||||
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body)
|
||||
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body)
|
||||
}
|
||||
fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out)
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func createGeminiCliReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().SetTimeout(30 * time.Second)
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest {
|
||||
return &geminicli.LoadCodeAssistRequest{
|
||||
Metadata: geminicli.LoadCodeAssistMetadata{
|
||||
IDEType: "ANTIGRAVITY",
|
||||
Platform: "PLATFORM_UNSPECIFIED",
|
||||
PluginType: "GEMINI",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func defaultOnboardUserRequest() *geminicli.OnboardUserRequest {
|
||||
return &geminicli.OnboardUserRequest{
|
||||
TierID: "LEGACY",
|
||||
Metadata: geminicli.LoadCodeAssistMetadata{
|
||||
IDEType: "ANTIGRAVITY",
|
||||
Platform: "PLATFORM_UNSPECIFIED",
|
||||
PluginType: "GEMINI",
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -2,10 +2,10 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -20,38 +20,50 @@ func NewGroupRepository(db *gorm.DB) service.GroupRepository {
|
||||
return &groupRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *groupRepository) Create(ctx context.Context, group *model.Group) error {
|
||||
err := r.db.WithContext(ctx).Create(group).Error
|
||||
func (r *groupRepository) Create(ctx context.Context, group *service.Group) error {
|
||||
m := groupModelFromService(group)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
if err == nil {
|
||||
applyGroupModelToService(group, m)
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrGroupExists)
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
|
||||
var group model.Group
|
||||
err := r.db.WithContext(ctx).First(&group, id).Error
|
||||
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
|
||||
var m groupModel
|
||||
err := r.db.WithContext(ctx).First(&m, id).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
return &group, nil
|
||||
group := groupModelToService(&m)
|
||||
count, _ := r.GetAccountCount(ctx, group.ID)
|
||||
group.AccountCount = count
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) Update(ctx context.Context, group *model.Group) error {
|
||||
return r.db.WithContext(ctx).Save(group).Error
|
||||
func (r *groupRepository) Update(ctx context.Context, group *service.Group) error {
|
||||
m := groupModelFromService(group)
|
||||
err := r.db.WithContext(ctx).Save(m).Error
|
||||
if err == nil {
|
||||
applyGroupModelToService(group, m)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
|
||||
return r.db.WithContext(ctx).Delete(&groupModel{}, id).Error
|
||||
}
|
||||
|
||||
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", nil)
|
||||
}
|
||||
|
||||
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
|
||||
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
var groups []model.Group
|
||||
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
var groups []groupModel
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.Group{})
|
||||
db := r.db.WithContext(ctx).Model(&groupModel{})
|
||||
|
||||
// Apply filters
|
||||
if platform != "" {
|
||||
@@ -72,68 +84,71 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 获取每个分组的账号数量
|
||||
outGroups := make([]service.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
count, _ := r.GetAccountCount(ctx, groups[i].ID)
|
||||
groups[i].AccountCount = count
|
||||
outGroups = append(outGroups, *groupModelToService(&groups[i]))
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
// 获取每个分组的账号数量
|
||||
for i := range outGroups {
|
||||
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
|
||||
outGroups[i].AccountCount = count
|
||||
}
|
||||
|
||||
return groups, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
return outGroups, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
|
||||
var groups []model.Group
|
||||
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error
|
||||
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
|
||||
var groups []groupModel
|
||||
err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Order("id ASC").Find(&groups).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 获取每个分组的账号数量
|
||||
outGroups := make([]service.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
count, _ := r.GetAccountCount(ctx, groups[i].ID)
|
||||
groups[i].AccountCount = count
|
||||
outGroups = append(outGroups, *groupModelToService(&groups[i]))
|
||||
}
|
||||
return groups, nil
|
||||
// 获取每个分组的账号数量
|
||||
for i := range outGroups {
|
||||
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
|
||||
outGroups[i].AccountCount = count
|
||||
}
|
||||
return outGroups, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
|
||||
var groups []model.Group
|
||||
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error
|
||||
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
var groups []groupModel
|
||||
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", service.StatusActive, platform).Order("id ASC").Find(&groups).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 获取每个分组的账号数量
|
||||
outGroups := make([]service.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
count, _ := r.GetAccountCount(ctx, groups[i].ID)
|
||||
groups[i].AccountCount = count
|
||||
outGroups = append(outGroups, *groupModelToService(&groups[i]))
|
||||
}
|
||||
return groups, nil
|
||||
// 获取每个分组的账号数量
|
||||
for i := range outGroups {
|
||||
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
|
||||
outGroups[i].AccountCount = count
|
||||
}
|
||||
return outGroups, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error
|
||||
err := r.db.WithContext(ctx).Model(&groupModel{}).Where("name = ?", name).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error
|
||||
err := r.db.WithContext(ctx).Table("account_groups").Where("group_id = ?", groupID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
|
||||
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{})
|
||||
result := r.db.WithContext(ctx).Exec("DELETE FROM account_groups WHERE group_id = ?", groupID)
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@@ -145,46 +160,42 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
|
||||
var affectedUserIDs []int64
|
||||
if group.IsSubscriptionType() {
|
||||
var subscriptions []model.UserSubscription
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&model.UserSubscription{}).
|
||||
Table("user_subscriptions").
|
||||
Where("group_id = ?", id).
|
||||
Select("user_id").
|
||||
Find(&subscriptions).Error; err != nil {
|
||||
Pluck("user_id", &affectedUserIDs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, sub := range subscriptions {
|
||||
affectedUserIDs = append(affectedUserIDs, sub.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 1. 删除订阅类型分组的订阅记录
|
||||
if group.IsSubscriptionType() {
|
||||
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
|
||||
if err := tx.Exec("DELETE FROM user_subscriptions WHERE group_id = ?", id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
|
||||
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
|
||||
if err := tx.Exec("UPDATE api_keys SET group_id = NULL WHERE group_id = ?", id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 3. 从 users.allowed_groups 数组中移除该分组 ID
|
||||
if err := tx.Model(&model.User{}).
|
||||
Where("? = ANY(allowed_groups)", id).
|
||||
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
|
||||
if err := tx.Exec(
|
||||
"UPDATE users SET allowed_groups = array_remove(allowed_groups, ?) WHERE ? = ANY(allowed_groups)",
|
||||
id, id,
|
||||
).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 4. 删除 account_groups 中间表的数据
|
||||
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
|
||||
if err := tx.Exec("DELETE FROM account_groups WHERE group_id = ?", id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 5. 删除分组本身(带锁,避免并发写)
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil {
|
||||
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&groupModel{}, id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -196,3 +207,75 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
|
||||
return affectedUserIDs, nil
|
||||
}
|
||||
|
||||
type groupModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
Name string `gorm:"uniqueIndex;size:100;not null"`
|
||||
Description string `gorm:"type:text"`
|
||||
Platform string `gorm:"size:50;default:anthropic;not null"`
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null"`
|
||||
IsExclusive bool `gorm:"default:false;not null"`
|
||||
Status string `gorm:"size:20;default:active;not null"`
|
||||
|
||||
SubscriptionType string `gorm:"size:20;default:standard;not null"`
|
||||
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
|
||||
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
|
||||
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
|
||||
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
func (groupModel) TableName() string { return "groups" }
|
||||
|
||||
func groupModelToService(m *groupModel) *service.Group {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Group{
|
||||
ID: m.ID,
|
||||
Name: m.Name,
|
||||
Description: m.Description,
|
||||
Platform: m.Platform,
|
||||
RateMultiplier: m.RateMultiplier,
|
||||
IsExclusive: m.IsExclusive,
|
||||
Status: m.Status,
|
||||
SubscriptionType: m.SubscriptionType,
|
||||
DailyLimitUSD: m.DailyLimitUSD,
|
||||
WeeklyLimitUSD: m.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: m.MonthlyLimitUSD,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func groupModelFromService(sg *service.Group) *groupModel {
|
||||
if sg == nil {
|
||||
return nil
|
||||
}
|
||||
return &groupModel{
|
||||
ID: sg.ID,
|
||||
Name: sg.Name,
|
||||
Description: sg.Description,
|
||||
Platform: sg.Platform,
|
||||
RateMultiplier: sg.RateMultiplier,
|
||||
IsExclusive: sg.IsExclusive,
|
||||
Status: sg.Status,
|
||||
SubscriptionType: sg.SubscriptionType,
|
||||
DailyLimitUSD: sg.DailyLimitUSD,
|
||||
WeeklyLimitUSD: sg.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: sg.MonthlyLimitUSD,
|
||||
CreatedAt: sg.CreatedAt,
|
||||
UpdatedAt: sg.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func applyGroupModelToService(group *service.Group, m *groupModel) {
|
||||
if group == nil || m == nil {
|
||||
return
|
||||
}
|
||||
group.ID = m.ID
|
||||
group.CreatedAt = m.CreatedAt
|
||||
group.UpdatedAt = m.UpdatedAt
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -32,10 +32,10 @@ func TestGroupRepoSuite(t *testing.T) {
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *GroupRepoSuite) TestCreate() {
|
||||
group := &model.Group{
|
||||
group := &service.Group{
|
||||
Name: "test-create",
|
||||
Platform: model.PlatformAnthropic,
|
||||
Status: model.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, group)
|
||||
@@ -53,7 +53,7 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdate() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "original"})
|
||||
group := groupModelToService(mustCreateGroup(s.T(), s.db, &groupModel{Name: "original"}))
|
||||
|
||||
group.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, group)
|
||||
@@ -65,7 +65,7 @@ func (s *GroupRepoSuite) TestUpdate() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestDelete() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "to-delete"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "to-delete"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
@@ -77,8 +77,8 @@ func (s *GroupRepoSuite) TestDelete() {
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *GroupRepoSuite) TestList() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
|
||||
|
||||
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
@@ -87,28 +87,28 @@ func (s *GroupRepoSuite) TestList() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI})
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformOpenAI, "", nil)
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(model.PlatformOpenAI, groups[0].Platform)
|
||||
s.Require().Equal(service.PlatformOpenAI, groups[0].Platform)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Status: model.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Status: model.StatusDisabled})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Status: service.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Status: service.StatusDisabled})
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, nil)
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(model.StatusDisabled, groups[0].Status)
|
||||
s.Require().Equal(service.StatusDisabled, groups[0].Status)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", IsExclusive: false})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", IsExclusive: true})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", IsExclusive: false})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", IsExclusive: true})
|
||||
|
||||
isExclusive := true
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
|
||||
@@ -118,24 +118,24 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
g1 := mustCreateGroup(s.T(), s.db, &model.Group{
|
||||
g1 := mustCreateGroup(s.T(), s.db, &groupModel{
|
||||
Name: "g1",
|
||||
Platform: model.PlatformAnthropic,
|
||||
Status: model.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &model.Group{
|
||||
g2 := mustCreateGroup(s.T(), s.db, &groupModel{
|
||||
Name: "g2",
|
||||
Platform: model.PlatformAnthropic,
|
||||
Status: model.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Status: service.StatusActive,
|
||||
IsExclusive: true,
|
||||
})
|
||||
|
||||
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
|
||||
a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
|
||||
|
||||
isExclusive := true
|
||||
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformAnthropic, model.StatusActive, &isExclusive)
|
||||
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(groups, 1)
|
||||
@@ -146,8 +146,8 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
// --- ListActive / ListActiveByPlatform ---
|
||||
|
||||
func (s *GroupRepoSuite) TestListActive() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "active1", Status: model.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "inactive1", Status: model.StatusDisabled})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "active1", Status: service.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "inactive1", Status: service.StatusDisabled})
|
||||
|
||||
groups, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
@@ -156,11 +156,11 @@ func (s *GroupRepoSuite) TestListActive() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListActiveByPlatform() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g3", Platform: model.PlatformAnthropic, Status: model.StatusDisabled})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g3", Platform: service.PlatformAnthropic, Status: service.StatusDisabled})
|
||||
|
||||
groups, err := s.repo.ListActiveByPlatform(s.ctx, model.PlatformAnthropic)
|
||||
groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
|
||||
s.Require().NoError(err, "ListActiveByPlatform")
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal("g1", groups[0].Name)
|
||||
@@ -169,7 +169,7 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
|
||||
// --- ExistsByName ---
|
||||
|
||||
func (s *GroupRepoSuite) TestExistsByName() {
|
||||
mustCreateGroup(s.T(), s.db, &model.Group{Name: "existing-group"})
|
||||
mustCreateGroup(s.T(), s.db, &groupModel{Name: "existing-group"})
|
||||
|
||||
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
|
||||
s.Require().NoError(err, "ExistsByName")
|
||||
@@ -183,9 +183,9 @@ func (s *GroupRepoSuite) TestExistsByName() {
|
||||
// --- GetAccountCount ---
|
||||
|
||||
func (s *GroupRepoSuite) TestGetAccountCount() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
|
||||
|
||||
@@ -195,7 +195,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
@@ -205,8 +205,8 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
||||
// --- DeleteAccountGroupsByGroupID ---
|
||||
|
||||
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
||||
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
|
||||
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
|
||||
g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
|
||||
a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
|
||||
|
||||
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
|
||||
@@ -219,10 +219,10 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
|
||||
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-multi"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
|
||||
a3 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"})
|
||||
g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-multi"})
|
||||
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
|
||||
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
|
||||
a3 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"})
|
||||
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
|
||||
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)
|
||||
|
||||
@@ -15,6 +15,11 @@ const (
|
||||
fingerprintTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
// fingerprintKey generates the Redis key for account fingerprint cache.
|
||||
func fingerprintKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
type identityCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -24,7 +29,7 @@ func NewIdentityCache(rdb *redis.Client) service.IdentityCache {
|
||||
}
|
||||
|
||||
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*service.Fingerprint, error) {
|
||||
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
key := fingerprintKey(accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -37,7 +42,7 @@ func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*s
|
||||
}
|
||||
|
||||
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *service.Fingerprint) error {
|
||||
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
key := fingerprintKey(accountID)
|
||||
val, err := json.Marshal(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
46
backend/internal/repository/identity_cache_test.go
Normal file
46
backend/internal/repository/identity_cache_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFingerprintKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_account_id",
|
||||
accountID: 123,
|
||||
expected: "fingerprint:123",
|
||||
},
|
||||
{
|
||||
name: "zero_account_id",
|
||||
accountID: 0,
|
||||
expected: "fingerprint:0",
|
||||
},
|
||||
{
|
||||
name: "negative_account_id",
|
||||
accountID: -1,
|
||||
expected: "fingerprint:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
accountID: math.MaxInt64,
|
||||
expected: "fingerprint:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := fingerprintKey(tc.accountID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -94,7 +93,7 @@ func TestMain(m *testing.M) {
|
||||
log.Printf("failed to open gorm db: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := model.AutoMigrate(integrationDB); err != nil {
|
||||
if err := AutoMigrate(integrationDB); err != nil {
|
||||
log.Printf("failed to automigrate db: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
16
backend/internal/repository/pagination.go
Normal file
16
backend/internal/repository/pagination.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package repository
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
func paginationResultFromTotal(total int64, params pagination.PaginationParams) *pagination.PaginationResult {
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
return &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}
|
||||
}
|
||||
@@ -120,10 +120,9 @@ func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
|
||||
started := make(chan struct{})
|
||||
block := make(chan struct{})
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(started)
|
||||
<-block
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
@@ -136,7 +135,6 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
|
||||
|
||||
<-started
|
||||
cancel()
|
||||
close(block)
|
||||
|
||||
err := <-done
|
||||
require.Error(s.T(), err)
|
||||
|
||||
@@ -2,10 +2,10 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -19,37 +19,47 @@ func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
|
||||
return &proxyRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
|
||||
return r.db.WithContext(ctx).Create(proxy).Error
|
||||
func (r *proxyRepository) Create(ctx context.Context, proxy *service.Proxy) error {
|
||||
m := proxyModelFromService(proxy)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
if err == nil {
|
||||
applyProxyModelToService(proxy, m)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
|
||||
var proxy model.Proxy
|
||||
err := r.db.WithContext(ctx).First(&proxy, id).Error
|
||||
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
|
||||
var m proxyModel
|
||||
err := r.db.WithContext(ctx).First(&m, id).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
|
||||
}
|
||||
return &proxy, nil
|
||||
return proxyModelToService(&m), nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
|
||||
return r.db.WithContext(ctx).Save(proxy).Error
|
||||
func (r *proxyRepository) Update(ctx context.Context, proxy *service.Proxy) error {
|
||||
m := proxyModelFromService(proxy)
|
||||
err := r.db.WithContext(ctx).Save(m).Error
|
||||
if err == nil {
|
||||
applyProxyModelToService(proxy, m)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
|
||||
return r.db.WithContext(ctx).Delete(&proxyModel{}, id).Error
|
||||
}
|
||||
|
||||
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
|
||||
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
var proxies []model.Proxy
|
||||
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
|
||||
var proxies []proxyModel
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.Proxy{})
|
||||
db := r.db.WithContext(ctx).Model(&proxyModel{})
|
||||
|
||||
// Apply filters
|
||||
if protocol != "" {
|
||||
@@ -71,29 +81,31 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
outProxies := make([]service.Proxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
|
||||
}
|
||||
|
||||
return proxies, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
return outProxies, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
|
||||
var proxies []model.Proxy
|
||||
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error
|
||||
return proxies, err
|
||||
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
|
||||
var proxies []proxyModel
|
||||
err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Find(&proxies).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outProxies := make([]service.Proxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
|
||||
}
|
||||
return outProxies, nil
|
||||
}
|
||||
|
||||
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
|
||||
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.Proxy{}).
|
||||
err := r.db.WithContext(ctx).Model(&proxyModel{}).
|
||||
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
@@ -105,7 +117,7 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
|
||||
// CountAccountsByProxyID returns the number of accounts using a specific proxy
|
||||
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.Account{}).
|
||||
err := r.db.WithContext(ctx).Table("accounts").
|
||||
Where("proxy_id = ?", proxyID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
@@ -119,7 +131,7 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
|
||||
}
|
||||
var results []result
|
||||
err := r.db.WithContext(ctx).
|
||||
Model(&model.Account{}).
|
||||
Table("accounts").
|
||||
Select("proxy_id, COUNT(*) as count").
|
||||
Where("proxy_id IS NOT NULL").
|
||||
Group("proxy_id").
|
||||
@@ -136,10 +148,10 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
|
||||
}
|
||||
|
||||
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
|
||||
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
|
||||
var proxies []model.Proxy
|
||||
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
|
||||
var proxies []proxyModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", model.StatusActive).
|
||||
Where("status = ?", service.StatusActive).
|
||||
Order("created_at DESC").
|
||||
Find(&proxies).Error
|
||||
if err != nil {
|
||||
@@ -153,13 +165,78 @@ func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]mod
|
||||
}
|
||||
|
||||
// Build result with account counts
|
||||
result := make([]model.ProxyWithAccountCount, len(proxies))
|
||||
for i, proxy := range proxies {
|
||||
result[i] = model.ProxyWithAccountCount{
|
||||
Proxy: proxy,
|
||||
AccountCount: counts[proxy.ID],
|
||||
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
proxy := proxyModelToService(&proxies[i])
|
||||
if proxy == nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, service.ProxyWithAccountCount{
|
||||
Proxy: *proxy,
|
||||
AccountCount: counts[proxy.ID],
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type proxyModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
Name string `gorm:"size:100;not null"`
|
||||
Protocol string `gorm:"size:20;not null"`
|
||||
Host string `gorm:"size:255;not null"`
|
||||
Port int `gorm:"not null"`
|
||||
Username string `gorm:"size:100"`
|
||||
Password string `gorm:"size:100"`
|
||||
Status string `gorm:"size:20;default:active;not null"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
func (proxyModel) TableName() string { return "proxies" }
|
||||
|
||||
func proxyModelToService(m *proxyModel) *service.Proxy {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Proxy{
|
||||
ID: m.ID,
|
||||
Name: m.Name,
|
||||
Protocol: m.Protocol,
|
||||
Host: m.Host,
|
||||
Port: m.Port,
|
||||
Username: m.Username,
|
||||
Password: m.Password,
|
||||
Status: m.Status,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func proxyModelFromService(p *service.Proxy) *proxyModel {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return &proxyModel{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
Protocol: p.Protocol,
|
||||
Host: p.Host,
|
||||
Port: p.Port,
|
||||
Username: p.Username,
|
||||
Password: p.Password,
|
||||
Status: p.Status,
|
||||
CreatedAt: p.CreatedAt,
|
||||
UpdatedAt: p.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func applyProxyModelToService(proxy *service.Proxy, m *proxyModel) {
|
||||
if proxy == nil || m == nil {
|
||||
return
|
||||
}
|
||||
proxy.ID = m.ID
|
||||
proxy.CreatedAt = m.CreatedAt
|
||||
proxy.UpdatedAt = m.UpdatedAt
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -33,12 +33,12 @@ func TestProxyRepoSuite(t *testing.T) {
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestCreate() {
|
||||
proxy := &model.Proxy{
|
||||
proxy := &service.Proxy{
|
||||
Name: "test-create",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Status: model.StatusActive,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, proxy)
|
||||
@@ -56,7 +56,7 @@ func (s *ProxyRepoSuite) TestGetByID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestUpdate() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"})
|
||||
proxy := proxyModelToService(mustCreateProxy(s.T(), s.db, &proxyModel{Name: "original"}))
|
||||
|
||||
proxy.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, proxy)
|
||||
@@ -68,7 +68,7 @@ func (s *ProxyRepoSuite) TestUpdate() {
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestDelete() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"})
|
||||
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "to-delete"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
@@ -80,8 +80,8 @@ func (s *ProxyRepoSuite) TestDelete() {
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestList() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
|
||||
|
||||
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
@@ -90,8 +90,8 @@ func (s *ProxyRepoSuite) TestList() {
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Protocol: "http"})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Protocol: "socks5"})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
|
||||
s.Require().NoError(err)
|
||||
@@ -100,18 +100,18 @@ func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Status: service.StatusActive})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Status: service.StatusDisabled})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "")
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(proxies, 1)
|
||||
s.Require().Equal(model.StatusDisabled, proxies[0].Status)
|
||||
s.Require().Equal(service.StatusDisabled, proxies[0].Status)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "production-proxy"})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "dev-proxy"})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
|
||||
s.Require().NoError(err)
|
||||
@@ -122,8 +122,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Search() {
|
||||
// --- ListActive ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestListActive() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "active1", Status: service.StatusActive})
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "inactive1", Status: service.StatusDisabled})
|
||||
|
||||
proxies, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
@@ -134,7 +134,7 @@ func (s *ProxyRepoSuite) TestListActive() {
|
||||
// --- ExistsByHostPortAuth ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
Name: "p1",
|
||||
Protocol: "http",
|
||||
Host: "1.2.3.4",
|
||||
@@ -153,7 +153,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
Name: "p-noauth",
|
||||
Protocol: "http",
|
||||
Host: "5.6.7.8",
|
||||
@@ -170,10 +170,10 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
|
||||
// --- CountAccountsByProxyID ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy
|
||||
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-count"})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &proxy.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &proxy.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"}) // no proxy
|
||||
|
||||
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "CountAccountsByProxyID")
|
||||
@@ -181,7 +181,7 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
|
||||
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"})
|
||||
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-zero"})
|
||||
|
||||
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err)
|
||||
@@ -191,12 +191,12 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
|
||||
// --- GetAccountCountsForProxies ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
|
||||
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
|
||||
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
|
||||
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
|
||||
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
|
||||
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
|
||||
|
||||
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||
s.Require().NoError(err, "GetAccountCountsForProxies")
|
||||
@@ -215,24 +215,24 @@ func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
|
||||
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
|
||||
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
Name: "p1",
|
||||
Status: model.StatusActive,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: base.Add(-1 * time.Hour),
|
||||
})
|
||||
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
Name: "p2",
|
||||
Status: model.StatusActive,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: base,
|
||||
})
|
||||
mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
Name: "p3-inactive",
|
||||
Status: model.StatusDisabled,
|
||||
Status: service.StatusDisabled,
|
||||
})
|
||||
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
|
||||
|
||||
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
|
||||
s.Require().NoError(err, "ListActiveWithAccountCount")
|
||||
@@ -248,7 +248,7 @@ func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
|
||||
// --- Combined original test ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
|
||||
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
Name: "p1",
|
||||
Protocol: "http",
|
||||
Host: "1.2.3.4",
|
||||
@@ -258,7 +258,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
|
||||
CreatedAt: time.Now().Add(-1 * time.Hour),
|
||||
UpdatedAt: time.Now().Add(-1 * time.Hour),
|
||||
})
|
||||
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
|
||||
Name: "p2",
|
||||
Protocol: "http",
|
||||
Host: "5.6.7.8",
|
||||
@@ -273,9 +273,9 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
|
||||
s.Require().NoError(err, "ExistsByHostPortAuth")
|
||||
s.Require().True(exists, "expected proxy to exist")
|
||||
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
|
||||
|
||||
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
|
||||
s.Require().NoError(err, "CountAccountsByProxyID")
|
||||
|
||||
@@ -15,6 +15,16 @@ const (
|
||||
redeemRateLimitDuration = 24 * time.Hour
|
||||
)
|
||||
|
||||
// redeemRateLimitKey generates the Redis key for redeem attempt rate limiting.
|
||||
func redeemRateLimitKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
}
|
||||
|
||||
// redeemLockKey generates the Redis key for redeem code locking.
|
||||
func redeemLockKey(code string) string {
|
||||
return redeemLockKeyPrefix + code
|
||||
}
|
||||
|
||||
type redeemCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -24,12 +34,16 @@ func NewRedeemCache(rdb *redis.Client) service.RedeemCache {
|
||||
}
|
||||
|
||||
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
return c.rdb.Get(ctx, key).Int()
|
||||
key := redeemRateLimitKey(userID)
|
||||
count, err := c.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
key := redeemRateLimitKey(userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, redeemRateLimitDuration)
|
||||
@@ -38,11 +52,11 @@ func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID in
|
||||
}
|
||||
|
||||
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
|
||||
key := redeemLockKeyPrefix + code
|
||||
key := redeemLockKey(code)
|
||||
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||
}
|
||||
|
||||
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
|
||||
key := redeemLockKeyPrefix + code
|
||||
key := redeemLockKey(code)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
@@ -3,12 +3,10 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
@@ -25,9 +23,9 @@ func (s *RedeemCacheSuite) SetupTest() {
|
||||
|
||||
func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() {
|
||||
missingUserID := int64(99999)
|
||||
_, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
|
||||
require.Error(s.T(), err, "expected redis.Nil for missing rate-limit key")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
count, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
|
||||
require.NoError(s.T(), err, "expected nil error for missing rate-limit key")
|
||||
require.Equal(s.T(), 0, count, "expected zero count for missing key")
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() {
|
||||
|
||||
77
backend/internal/repository/redeem_cache_test.go
Normal file
77
backend/internal/repository/redeem_cache_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRedeemRateLimitKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_user_id",
|
||||
userID: 123,
|
||||
expected: "redeem:ratelimit:123",
|
||||
},
|
||||
{
|
||||
name: "zero_user_id",
|
||||
userID: 0,
|
||||
expected: "redeem:ratelimit:0",
|
||||
},
|
||||
{
|
||||
name: "negative_user_id",
|
||||
userID: -1,
|
||||
expected: "redeem:ratelimit:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
userID: math.MaxInt64,
|
||||
expected: "redeem:ratelimit:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := redeemRateLimitKey(tc.userID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedeemLockKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_code",
|
||||
code: "ABC123",
|
||||
expected: "redeem:lock:ABC123",
|
||||
},
|
||||
{
|
||||
name: "empty_code",
|
||||
code: "",
|
||||
expected: "redeem:lock:",
|
||||
},
|
||||
{
|
||||
name: "code_with_special_chars",
|
||||
code: "CODE-2024:test",
|
||||
expected: "redeem:lock:CODE-2024:test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := redeemLockKey(tc.code)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,10 +4,8 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -20,48 +18,61 @@ func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository {
|
||||
return &redeemCodeRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error {
|
||||
return r.db.WithContext(ctx).Create(code).Error
|
||||
func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemCode) error {
|
||||
m := redeemCodeModelFromService(code)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
if err == nil {
|
||||
applyRedeemCodeModelToService(code, m)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error {
|
||||
return r.db.WithContext(ctx).Create(&codes).Error
|
||||
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.RedeemCode) error {
|
||||
if len(codes) == 0 {
|
||||
return nil
|
||||
}
|
||||
models := make([]redeemCodeModel, 0, len(codes))
|
||||
for i := range codes {
|
||||
m := redeemCodeModelFromService(&codes[i])
|
||||
if m != nil {
|
||||
models = append(models, *m)
|
||||
}
|
||||
}
|
||||
return r.db.WithContext(ctx).Create(&models).Error
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
|
||||
var code model.RedeemCode
|
||||
err := r.db.WithContext(ctx).First(&code, id).Error
|
||||
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
|
||||
var m redeemCodeModel
|
||||
err := r.db.WithContext(ctx).First(&m, id).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
|
||||
}
|
||||
return &code, nil
|
||||
return redeemCodeModelToService(&m), nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
|
||||
var redeemCode model.RedeemCode
|
||||
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error
|
||||
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
|
||||
var m redeemCodeModel
|
||||
err := r.db.WithContext(ctx).Where("code = ?", code).First(&m).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
|
||||
}
|
||||
return &redeemCode, nil
|
||||
return redeemCodeModelToService(&m), nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
|
||||
return r.db.WithContext(ctx).Delete(&redeemCodeModel{}, id).Error
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
|
||||
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
var codes []model.RedeemCode
|
||||
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
var codes []redeemCodeModel
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.RedeemCode{})
|
||||
db := r.db.WithContext(ctx).Model(&redeemCodeModel{})
|
||||
|
||||
// Apply filters
|
||||
if codeType != "" {
|
||||
db = db.Where("type = ?", codeType)
|
||||
}
|
||||
@@ -81,29 +92,29 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
outCodes := make([]service.RedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
|
||||
}
|
||||
|
||||
return codes, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
return outCodes, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error {
|
||||
return r.db.WithContext(ctx).Save(code).Error
|
||||
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
|
||||
m := redeemCodeModelFromService(code)
|
||||
err := r.db.WithContext(ctx).Save(m).Error
|
||||
if err == nil {
|
||||
applyRedeemCodeModelToService(code, m)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
|
||||
now := time.Now()
|
||||
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
|
||||
Where("id = ? AND status = ?", id, model.StatusUnused).
|
||||
result := r.db.WithContext(ctx).Model(&redeemCodeModel{}).
|
||||
Where("id = ? AND status = ?", id, service.StatusUnused).
|
||||
Updates(map[string]any{
|
||||
"status": model.StatusUsed,
|
||||
"status": service.StatusUsed,
|
||||
"used_by": userID,
|
||||
"used_at": now,
|
||||
})
|
||||
@@ -116,22 +127,93 @@ func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListByUser returns all redeem codes used by a specific user
|
||||
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
|
||||
var codes []model.RedeemCode
|
||||
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
var codes []redeemCodeModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("used_by = ?", userID).
|
||||
Order("used_at DESC").
|
||||
Limit(limit).
|
||||
Find(&codes).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return codes, nil
|
||||
|
||||
outCodes := make([]service.RedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
|
||||
}
|
||||
return outCodes, nil
|
||||
}
|
||||
|
||||
type redeemCodeModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
Code string `gorm:"uniqueIndex;size:32;not null"`
|
||||
Type string `gorm:"size:20;default:balance;not null"`
|
||||
Value float64 `gorm:"type:decimal(20,8);not null"`
|
||||
Status string `gorm:"size:20;default:unused;not null"`
|
||||
UsedBy *int64 `gorm:"index"`
|
||||
UsedAt *time.Time
|
||||
Notes string `gorm:"type:text"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
|
||||
GroupID *int64 `gorm:"index"`
|
||||
ValidityDays int `gorm:"default:30"`
|
||||
|
||||
User *userModel `gorm:"foreignKey:UsedBy"`
|
||||
Group *groupModel `gorm:"foreignKey:GroupID"`
|
||||
}
|
||||
|
||||
func (redeemCodeModel) TableName() string { return "redeem_codes" }
|
||||
|
||||
func redeemCodeModelToService(m *redeemCodeModel) *service.RedeemCode {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.RedeemCode{
|
||||
ID: m.ID,
|
||||
Code: m.Code,
|
||||
Type: m.Type,
|
||||
Value: m.Value,
|
||||
Status: m.Status,
|
||||
UsedBy: m.UsedBy,
|
||||
UsedAt: m.UsedAt,
|
||||
Notes: m.Notes,
|
||||
CreatedAt: m.CreatedAt,
|
||||
GroupID: m.GroupID,
|
||||
ValidityDays: m.ValidityDays,
|
||||
User: userModelToService(m.User),
|
||||
Group: groupModelToService(m.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func redeemCodeModelFromService(r *service.RedeemCode) *redeemCodeModel {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return &redeemCodeModel{
|
||||
ID: r.ID,
|
||||
Code: r.Code,
|
||||
Type: r.Type,
|
||||
Value: r.Value,
|
||||
Status: r.Status,
|
||||
UsedBy: r.UsedBy,
|
||||
UsedAt: r.UsedAt,
|
||||
Notes: r.Notes,
|
||||
CreatedAt: r.CreatedAt,
|
||||
GroupID: r.GroupID,
|
||||
ValidityDays: r.ValidityDays,
|
||||
}
|
||||
}
|
||||
|
||||
func applyRedeemCodeModelToService(code *service.RedeemCode, m *redeemCodeModel) {
|
||||
if code == nil || m == nil {
|
||||
return
|
||||
}
|
||||
code.ID = m.ID
|
||||
code.CreatedAt = m.CreatedAt
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -34,11 +33,11 @@ func TestRedeemCodeRepoSuite(t *testing.T) {
|
||||
// --- Create / CreateBatch / GetByID / GetByCode ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestCreate() {
|
||||
code := &model.RedeemCode{
|
||||
code := &service.RedeemCode{
|
||||
Code: "TEST-CREATE",
|
||||
Type: model.RedeemTypeBalance,
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 100,
|
||||
Status: model.StatusUnused,
|
||||
Status: service.StatusUnused,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, code)
|
||||
@@ -51,9 +50,9 @@ func (s *RedeemCodeRepoSuite) TestCreate() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
|
||||
codes := []model.RedeemCode{
|
||||
{Code: "BATCH-1", Type: model.RedeemTypeBalance, Value: 10, Status: model.StatusUnused},
|
||||
{Code: "BATCH-2", Type: model.RedeemTypeBalance, Value: 20, Status: model.StatusUnused},
|
||||
codes := []service.RedeemCode{
|
||||
{Code: "BATCH-1", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused},
|
||||
{Code: "BATCH-2", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused},
|
||||
}
|
||||
|
||||
err := s.repo.CreateBatch(s.ctx, codes)
|
||||
@@ -74,7 +73,7 @@ func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestGetByCode() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "GET-BY-CODE", Type: model.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "GET-BY-CODE", Type: service.RedeemTypeBalance})
|
||||
|
||||
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
|
||||
s.Require().NoError(err, "GetByCode")
|
||||
@@ -89,7 +88,7 @@ func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
|
||||
// --- Delete ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestDelete() {
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TO-DELETE", Type: model.RedeemTypeBalance})
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TO-DELETE", Type: service.RedeemTypeBalance})
|
||||
|
||||
err := s.repo.Delete(s.ctx, code.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
@@ -101,8 +100,8 @@ func (s *RedeemCodeRepoSuite) TestDelete() {
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestList() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-1", Type: model.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-2", Type: model.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-1", Type: service.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-2", Type: service.RedeemTypeBalance})
|
||||
|
||||
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
@@ -111,28 +110,28 @@ func (s *RedeemCodeRepoSuite) TestList() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-BAL", Type: model.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-SUB", Type: model.RedeemTypeSubscription})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-BAL", Type: service.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription})
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, "", "")
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
s.Require().Equal(model.RedeemTypeSubscription, codes[0].Type)
|
||||
s.Require().Equal(service.RedeemTypeSubscription, codes[0].Type)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-UNUSED", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusUsed, "")
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUsed, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
s.Require().Equal(model.StatusUsed, codes[0].Status)
|
||||
s.Require().Equal(service.StatusUsed, codes[0].Status)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALPHA-CODE", Type: model.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "BETA-CODE", Type: model.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "BETA-CODE", Type: service.RedeemTypeBalance})
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
|
||||
s.Require().NoError(err)
|
||||
@@ -141,10 +140,10 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
|
||||
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
|
||||
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
|
||||
Code: "WITH-GROUP",
|
||||
Type: model.RedeemTypeSubscription,
|
||||
Type: service.RedeemTypeSubscription,
|
||||
GroupID: &group.ID,
|
||||
})
|
||||
|
||||
@@ -158,7 +157,7 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
|
||||
// --- Update ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUpdate() {
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "UPDATE-ME", Type: model.RedeemTypeBalance, Value: 10})
|
||||
code := redeemCodeModelToService(mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "UPDATE-ME", Type: service.RedeemTypeBalance, Value: 10}))
|
||||
|
||||
code.Value = 50
|
||||
err := s.repo.Update(s.ctx, code)
|
||||
@@ -172,23 +171,23 @@ func (s *RedeemCodeRepoSuite) TestUpdate() {
|
||||
// --- Use ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "use@test.com"})
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "USE-ME", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "use@test.com"})
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "USE-ME", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
|
||||
|
||||
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().NoError(err, "Use")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, code.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(model.StatusUsed, got.Status)
|
||||
s.Require().Equal(service.StatusUsed, got.Status)
|
||||
s.Require().NotNil(got.UsedBy)
|
||||
s.Require().Equal(user.ID, *got.UsedBy)
|
||||
s.Require().NotNil(got.UsedAt)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "idem@test.com"})
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "IDEM-CODE", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "idem@test.com"})
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
|
||||
|
||||
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().NoError(err, "Use first time")
|
||||
@@ -200,8 +199,8 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "already@test.com"})
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALREADY-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "already@test.com"})
|
||||
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
|
||||
|
||||
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().Error(err, "expected error for already used code")
|
||||
@@ -211,22 +210,22 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
|
||||
// --- ListByUser ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListByUser() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
|
||||
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Create codes with explicit used_at for ordering
|
||||
c1 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||
c1 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
|
||||
Code: "USER-1",
|
||||
Type: model.RedeemTypeBalance,
|
||||
Status: model.StatusUsed,
|
||||
Type: service.RedeemTypeBalance,
|
||||
Status: service.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
})
|
||||
s.db.Model(c1).Update("used_at", base)
|
||||
|
||||
c2 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||
c2 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
|
||||
Code: "USER-2",
|
||||
Type: model.RedeemTypeBalance,
|
||||
Status: model.StatusUsed,
|
||||
Type: service.RedeemTypeBalance,
|
||||
Status: service.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
})
|
||||
s.db.Model(c2).Update("used_at", base.Add(1*time.Hour))
|
||||
@@ -240,13 +239,13 @@ func (s *RedeemCodeRepoSuite) TestListByUser() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listby"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "grp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listby"})
|
||||
|
||||
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||
c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
|
||||
Code: "WITH-GRP",
|
||||
Type: model.RedeemTypeSubscription,
|
||||
Status: model.StatusUsed,
|
||||
Type: service.RedeemTypeSubscription,
|
||||
Status: service.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
GroupID: &group.ID,
|
||||
})
|
||||
@@ -260,11 +259,11 @@ func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deflimit@test.com"})
|
||||
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "deflimit@test.com"})
|
||||
c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
|
||||
Code: "DEF-LIM",
|
||||
Type: model.RedeemTypeBalance,
|
||||
Status: model.StatusUsed,
|
||||
Type: service.RedeemTypeBalance,
|
||||
Status: service.StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
})
|
||||
s.db.Model(c).Update("used_at", time.Now())
|
||||
@@ -278,16 +277,16 @@ func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
|
||||
// --- Combined original test ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "rc@example.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-rc"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "rc@example.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-rc"})
|
||||
|
||||
codes := []model.RedeemCode{
|
||||
{Code: "CODEA", Type: model.RedeemTypeBalance, Value: 1, Status: model.StatusUnused, CreatedAt: time.Now()},
|
||||
{Code: "CODEB", Type: model.RedeemTypeSubscription, Value: 0, Status: model.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
|
||||
codes := []service.RedeemCode{
|
||||
{Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, CreatedAt: time.Now()},
|
||||
{Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
|
||||
}
|
||||
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
|
||||
|
||||
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, model.StatusUnused, "code")
|
||||
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, service.StatusUnused, "code")
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(list, 1)
|
||||
@@ -305,9 +304,9 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
|
||||
s.Require().NoError(err, "GetByCode")
|
||||
|
||||
// Use fixed time instead of time.Sleep for deterministic ordering
|
||||
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
s.db.Model(&redeemCodeModel{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
|
||||
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
|
||||
s.db.Model(&redeemCodeModel{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
|
||||
|
||||
used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err, "ListByUser")
|
||||
|
||||
@@ -6,33 +6,27 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// SettingRepository 系统设置数据访问层
|
||||
type settingRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewSettingRepository 创建系统设置仓库实例
|
||||
func NewSettingRepository(db *gorm.DB) service.SettingRepository {
|
||||
return &settingRepository{db: db}
|
||||
}
|
||||
|
||||
// Get 根据Key获取设置值
|
||||
func (r *settingRepository) Get(ctx context.Context, key string) (*model.Setting, error) {
|
||||
var setting model.Setting
|
||||
err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error
|
||||
func (r *settingRepository) Get(ctx context.Context, key string) (*service.Setting, error) {
|
||||
var m settingModel
|
||||
err := r.db.WithContext(ctx).Where("key = ?", key).First(&m).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil)
|
||||
}
|
||||
return &setting, nil
|
||||
return settingModelToService(&m), nil
|
||||
}
|
||||
|
||||
// GetValue 获取设置值字符串
|
||||
func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
|
||||
setting, err := r.Get(ctx, key)
|
||||
if err != nil {
|
||||
@@ -41,9 +35,8 @@ func (r *settingRepository) GetValue(ctx context.Context, key string) (string, e
|
||||
return setting.Value, nil
|
||||
}
|
||||
|
||||
// Set 设置值(存在则更新,不存在则创建)
|
||||
func (r *settingRepository) Set(ctx context.Context, key, value string) error {
|
||||
setting := &model.Setting{
|
||||
m := &settingModel{
|
||||
Key: key,
|
||||
Value: value,
|
||||
UpdatedAt: time.Now(),
|
||||
@@ -52,12 +45,11 @@ func (r *settingRepository) Set(ctx context.Context, key, value string) error {
|
||||
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "key"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
|
||||
}).Create(setting).Error
|
||||
}).Create(m).Error
|
||||
}
|
||||
|
||||
// GetMultiple 批量获取设置
|
||||
func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
var settings []model.Setting
|
||||
var settings []settingModel
|
||||
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -70,11 +62,10 @@ func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SetMultiple 批量设置值
|
||||
func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for key, value := range settings {
|
||||
setting := &model.Setting{
|
||||
m := &settingModel{
|
||||
Key: key,
|
||||
Value: value,
|
||||
UpdatedAt: time.Now(),
|
||||
@@ -82,7 +73,7 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string
|
||||
if err := tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "key"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
|
||||
}).Create(setting).Error; err != nil {
|
||||
}).Create(m).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -90,9 +81,8 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string
|
||||
})
|
||||
}
|
||||
|
||||
// GetAll 获取所有设置
|
||||
func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
var settings []model.Setting
|
||||
var settings []settingModel
|
||||
err := r.db.WithContext(ctx).Find(&settings).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -105,7 +95,27 @@ func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, erro
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Delete 删除设置
|
||||
func (r *settingRepository) Delete(ctx context.Context, key string) error {
|
||||
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error
|
||||
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&settingModel{}).Error
|
||||
}
|
||||
|
||||
type settingModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
Key string `gorm:"uniqueIndex;size:100;not null"`
|
||||
Value string `gorm:"type:text;not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
}
|
||||
|
||||
func (settingModel) TableName() string { return "settings" }
|
||||
|
||||
func settingModelToService(m *settingModel) *service.Setting {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Setting{
|
||||
ID: m.ID,
|
||||
Key: m.Key,
|
||||
Value: m.Value,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
@@ -30,7 +29,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
|
||||
TokenCount int64 `gorm:"column:token_count"`
|
||||
}
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
db := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
COUNT(*) as request_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
|
||||
@@ -46,24 +45,29 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
|
||||
return perfStats.RequestCount / 5, perfStats.TokenCount / 5
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
|
||||
return r.db.WithContext(ctx).Create(log).Error
|
||||
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error {
|
||||
m := usageLogModelFromService(log)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
if err == nil {
|
||||
applyUsageLogModelToService(log, m)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
|
||||
var log model.UsageLog
|
||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
|
||||
var log usageLogModel
|
||||
err := r.db.WithContext(ctx).First(&log, id).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil)
|
||||
}
|
||||
return &log, nil
|
||||
return usageLogModelToService(&log), nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []usageLogModel
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("user_id = ?", userID)
|
||||
db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("user_id = ?", userID)
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
@@ -73,24 +77,14 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return logs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []usageLogModel
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("api_key_id = ?", apiKeyID)
|
||||
db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("api_key_id = ?", apiKeyID)
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
@@ -100,17 +94,7 @@ func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return logs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
// UserStats 用户使用统计
|
||||
@@ -125,7 +109,7 @@ type UserStats struct {
|
||||
|
||||
func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
|
||||
var stats UserStats
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
@@ -145,51 +129,67 @@ type DashboardStats = usagestats.DashboardStats
|
||||
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||
var stats DashboardStats
|
||||
today := timezone.Today()
|
||||
now := time.Now()
|
||||
|
||||
// 总用户数
|
||||
r.db.WithContext(ctx).Model(&model.User{}).Count(&stats.TotalUsers)
|
||||
// 合并用户统计查询
|
||||
var userStats struct {
|
||||
TotalUsers int64 `gorm:"column:total_users"`
|
||||
TodayNewUsers int64 `gorm:"column:today_new_users"`
|
||||
ActiveUsers int64 `gorm:"column:active_users"`
|
||||
}
|
||||
if err := r.db.WithContext(ctx).Raw(`
|
||||
SELECT
|
||||
COUNT(*) as total_users,
|
||||
COUNT(CASE WHEN created_at >= ? THEN 1 END) as today_new_users,
|
||||
(SELECT COUNT(DISTINCT user_id) FROM usage_logs WHERE created_at >= ?) as active_users
|
||||
FROM users
|
||||
`, today, today).Scan(&userStats).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalUsers = userStats.TotalUsers
|
||||
stats.TodayNewUsers = userStats.TodayNewUsers
|
||||
stats.ActiveUsers = userStats.ActiveUsers
|
||||
|
||||
// 今日新增用户数
|
||||
r.db.WithContext(ctx).Model(&model.User{}).
|
||||
Where("created_at >= ?", today).
|
||||
Count(&stats.TodayNewUsers)
|
||||
// 合并API Key统计查询
|
||||
var apiKeyStats struct {
|
||||
TotalApiKeys int64 `gorm:"column:total_api_keys"`
|
||||
ActiveApiKeys int64 `gorm:"column:active_api_keys"`
|
||||
}
|
||||
if err := r.db.WithContext(ctx).Raw(`
|
||||
SELECT
|
||||
COUNT(*) as total_api_keys,
|
||||
COUNT(CASE WHEN status = ? THEN 1 END) as active_api_keys
|
||||
FROM api_keys
|
||||
`, service.StatusActive).Scan(&apiKeyStats).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalApiKeys = apiKeyStats.TotalApiKeys
|
||||
stats.ActiveApiKeys = apiKeyStats.ActiveApiKeys
|
||||
|
||||
// 今日活跃用户数 (今日有请求的用户)
|
||||
r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
Distinct("user_id").
|
||||
Where("created_at >= ?", today).
|
||||
Count(&stats.ActiveUsers)
|
||||
|
||||
// 总 API Key 数
|
||||
r.db.WithContext(ctx).Model(&model.ApiKey{}).Count(&stats.TotalApiKeys)
|
||||
|
||||
// 活跃 API Key 数
|
||||
r.db.WithContext(ctx).Model(&model.ApiKey{}).
|
||||
Where("status = ?", model.StatusActive).
|
||||
Count(&stats.ActiveApiKeys)
|
||||
|
||||
// 总账户数
|
||||
r.db.WithContext(ctx).Model(&model.Account{}).Count(&stats.TotalAccounts)
|
||||
|
||||
// 正常账户数 (schedulable=true, status=active)
|
||||
r.db.WithContext(ctx).Model(&model.Account{}).
|
||||
Where("status = ? AND schedulable = ?", model.StatusActive, true).
|
||||
Count(&stats.NormalAccounts)
|
||||
|
||||
// 异常账户数 (status=error)
|
||||
r.db.WithContext(ctx).Model(&model.Account{}).
|
||||
Where("status = ?", model.StatusError).
|
||||
Count(&stats.ErrorAccounts)
|
||||
|
||||
// 限流账户数
|
||||
r.db.WithContext(ctx).Model(&model.Account{}).
|
||||
Where("rate_limited_at IS NOT NULL AND rate_limit_reset_at > ?", time.Now()).
|
||||
Count(&stats.RateLimitAccounts)
|
||||
|
||||
// 过载账户数
|
||||
r.db.WithContext(ctx).Model(&model.Account{}).
|
||||
Where("overload_until IS NOT NULL AND overload_until > ?", time.Now()).
|
||||
Count(&stats.OverloadAccounts)
|
||||
// 合并账户统计查询
|
||||
var accountStats struct {
|
||||
TotalAccounts int64 `gorm:"column:total_accounts"`
|
||||
NormalAccounts int64 `gorm:"column:normal_accounts"`
|
||||
ErrorAccounts int64 `gorm:"column:error_accounts"`
|
||||
RateLimitAccounts int64 `gorm:"column:ratelimit_accounts"`
|
||||
OverloadAccounts int64 `gorm:"column:overload_accounts"`
|
||||
}
|
||||
if err := r.db.WithContext(ctx).Raw(`
|
||||
SELECT
|
||||
COUNT(*) as total_accounts,
|
||||
COUNT(CASE WHEN status = ? AND schedulable = true THEN 1 END) as normal_accounts,
|
||||
COUNT(CASE WHEN status = ? THEN 1 END) as error_accounts,
|
||||
COUNT(CASE WHEN rate_limited_at IS NOT NULL AND rate_limit_reset_at > ? THEN 1 END) as ratelimit_accounts,
|
||||
COUNT(CASE WHEN overload_until IS NOT NULL AND overload_until > ? THEN 1 END) as overload_accounts
|
||||
FROM accounts
|
||||
`, service.StatusActive, service.StatusError, now, now).Scan(&accountStats).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalAccounts = accountStats.TotalAccounts
|
||||
stats.NormalAccounts = accountStats.NormalAccounts
|
||||
stats.ErrorAccounts = accountStats.ErrorAccounts
|
||||
stats.RateLimitAccounts = accountStats.RateLimitAccounts
|
||||
stats.OverloadAccounts = accountStats.OverloadAccounts
|
||||
|
||||
// 累计 Token 统计
|
||||
var totalStats struct {
|
||||
@@ -202,7 +202,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
TotalActualCost float64 `gorm:"column:total_actual_cost"`
|
||||
AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
|
||||
}
|
||||
r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
@@ -235,7 +235,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
TodayCost float64 `gorm:"column:today_cost"`
|
||||
TodayActualCost float64 `gorm:"column:today_actual_cost"`
|
||||
}
|
||||
r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
COUNT(*) as today_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
|
||||
@@ -263,11 +263,11 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []usageLogModel
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("account_id = ?", accountID)
|
||||
db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("account_id = ?", accountID)
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
@@ -277,57 +277,129 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return logs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []usageLogModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
|
||||
Order("id DESC").
|
||||
Find(&logs).Error
|
||||
return logs, nil, err
|
||||
return usageLogModelsToService(logs), nil, err
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
// GetUserStatsAggregated returns aggregated usage statistics for a user using database-level aggregation
|
||||
func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
var stats struct {
|
||||
TotalRequests int64 `gorm:"column:total_requests"`
|
||||
TotalInputTokens int64 `gorm:"column:total_input_tokens"`
|
||||
TotalOutputTokens int64 `gorm:"column:total_output_tokens"`
|
||||
TotalCacheTokens int64 `gorm:"column:total_cache_tokens"`
|
||||
TotalCost float64 `gorm:"column:total_cost"`
|
||||
TotalActualCost float64 `gorm:"column:total_actual_cost"`
|
||||
AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||||
`).
|
||||
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
|
||||
Scan(&stats).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &usagestats.UsageStats{
|
||||
TotalRequests: stats.TotalRequests,
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheTokens: stats.TotalCacheTokens,
|
||||
TotalTokens: stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
TotalActualCost: stats.TotalActualCost,
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
|
||||
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
var stats struct {
|
||||
TotalRequests int64 `gorm:"column:total_requests"`
|
||||
TotalInputTokens int64 `gorm:"column:total_input_tokens"`
|
||||
TotalOutputTokens int64 `gorm:"column:total_output_tokens"`
|
||||
TotalCacheTokens int64 `gorm:"column:total_cache_tokens"`
|
||||
TotalCost float64 `gorm:"column:total_cost"`
|
||||
TotalActualCost float64 `gorm:"column:total_actual_cost"`
|
||||
AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||||
`).
|
||||
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
|
||||
Scan(&stats).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &usagestats.UsageStats{
|
||||
TotalRequests: stats.TotalRequests,
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheTokens: stats.TotalCacheTokens,
|
||||
TotalTokens: stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
TotalActualCost: stats.TotalActualCost,
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []usageLogModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
|
||||
Order("id DESC").
|
||||
Find(&logs).Error
|
||||
return logs, nil, err
|
||||
return usageLogModelsToService(logs), nil, err
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []usageLogModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
|
||||
Order("id DESC").
|
||||
Find(&logs).Error
|
||||
return logs, nil, err
|
||||
return usageLogModelsToService(logs), nil, err
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []usageLogModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
|
||||
Order("id DESC").
|
||||
Find(&logs).Error
|
||||
return logs, nil, err
|
||||
return usageLogModelsToService(logs), nil, err
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
|
||||
return r.db.WithContext(ctx).Delete(&usageLogModel{}, id).Error
|
||||
}
|
||||
|
||||
// GetAccountTodayStats 获取账号今日统计
|
||||
@@ -340,7 +412,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
||||
Cost float64 `gorm:"column:cost"`
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
@@ -368,7 +440,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
||||
Cost float64 `gorm:"column:cost"`
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
@@ -499,12 +571,12 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
today := timezone.Today()
|
||||
|
||||
// API Key 统计
|
||||
r.db.WithContext(ctx).Model(&model.ApiKey{}).
|
||||
r.db.WithContext(ctx).Model(&apiKeyModel{}).
|
||||
Where("user_id = ?", userID).
|
||||
Count(&stats.TotalApiKeys)
|
||||
|
||||
r.db.WithContext(ctx).Model(&model.ApiKey{}).
|
||||
Where("user_id = ? AND status = ?", userID, model.StatusActive).
|
||||
r.db.WithContext(ctx).Model(&apiKeyModel{}).
|
||||
Where("user_id = ? AND status = ?", userID, service.StatusActive).
|
||||
Count(&stats.ActiveApiKeys)
|
||||
|
||||
// 累计 Token 统计
|
||||
@@ -518,7 +590,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
TotalActualCost float64 `gorm:"column:total_actual_cost"`
|
||||
AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
|
||||
}
|
||||
r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
@@ -552,7 +624,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
TodayCost float64 `gorm:"column:today_cost"`
|
||||
TodayActualCost float64 `gorm:"column:today_actual_cost"`
|
||||
}
|
||||
r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
COUNT(*) as today_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
|
||||
@@ -591,7 +663,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
|
||||
dateFormat = "YYYY-MM-DD"
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
TO_CHAR(created_at, ?) as date,
|
||||
COUNT(*) as requests,
|
||||
@@ -618,7 +690,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
|
||||
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
|
||||
var results []ModelStat
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
model,
|
||||
COUNT(*) as requests,
|
||||
@@ -644,11 +716,11 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
|
||||
type UsageLogFilters = usagestats.UsageLogFilters
|
||||
|
||||
// ListWithFilters lists usage logs with optional filters (for admin)
|
||||
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []model.UsageLog
|
||||
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
var logs []usageLogModel
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{})
|
||||
db := r.db.WithContext(ctx).Model(&usageLogModel{})
|
||||
|
||||
// Apply filters
|
||||
if filters.UserID > 0 {
|
||||
@@ -657,6 +729,21 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
if filters.ApiKeyID > 0 {
|
||||
db = db.Where("api_key_id = ?", filters.ApiKeyID)
|
||||
}
|
||||
if filters.AccountID > 0 {
|
||||
db = db.Where("account_id = ?", filters.AccountID)
|
||||
}
|
||||
if filters.GroupID > 0 {
|
||||
db = db.Where("group_id = ?", filters.GroupID)
|
||||
}
|
||||
if filters.Model != "" {
|
||||
db = db.Where("model = ?", filters.Model)
|
||||
}
|
||||
if filters.Stream != nil {
|
||||
db = db.Where("stream = ?", *filters.Stream)
|
||||
}
|
||||
if filters.BillingType != nil {
|
||||
db = db.Where("billing_type = ?", *filters.BillingType)
|
||||
}
|
||||
if filters.StartTime != nil {
|
||||
db = db.Where("created_at >= ?", *filters.StartTime)
|
||||
}
|
||||
@@ -668,24 +755,14 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Preload user and api_key for display
|
||||
if err := db.Preload("User").Preload("ApiKey").
|
||||
// Preload user, api_key, account, and group for display
|
||||
if err := db.Preload("User").Preload("ApiKey").Preload("Account").Preload("Group").
|
||||
Offset(params.Offset()).Limit(params.Limit()).
|
||||
Order("id DESC").Find(&logs).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return logs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
@@ -713,7 +790,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
UserID int64 `gorm:"column:user_id"`
|
||||
TotalCost float64 `gorm:"column:total_cost"`
|
||||
}
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select("user_id, COALESCE(SUM(actual_cost), 0) as total_cost").
|
||||
Where("user_id IN ?", userIDs).
|
||||
Group("user_id").
|
||||
@@ -733,7 +810,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
UserID int64 `gorm:"column:user_id"`
|
||||
TodayCost float64 `gorm:"column:today_cost"`
|
||||
}
|
||||
err = r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
err = r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select("user_id, COALESCE(SUM(actual_cost), 0) as today_cost").
|
||||
Where("user_id IN ? AND created_at >= ?", userIDs, today).
|
||||
Group("user_id").
|
||||
@@ -773,7 +850,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
|
||||
ApiKeyID int64 `gorm:"column:api_key_id"`
|
||||
TotalCost float64 `gorm:"column:total_cost"`
|
||||
}
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select("api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost").
|
||||
Where("api_key_id IN ?", apiKeyIDs).
|
||||
Group("api_key_id").
|
||||
@@ -793,7 +870,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
|
||||
ApiKeyID int64 `gorm:"column:api_key_id"`
|
||||
TodayCost float64 `gorm:"column:today_cost"`
|
||||
}
|
||||
err = r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
err = r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select("api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost").
|
||||
Where("api_key_id IN ? AND created_at >= ?", apiKeyIDs, today).
|
||||
Group("api_key_id").
|
||||
@@ -822,7 +899,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
dateFormat = "YYYY-MM-DD"
|
||||
}
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
db := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
TO_CHAR(created_at, ?) as date,
|
||||
COUNT(*) as requests,
|
||||
@@ -854,7 +931,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
|
||||
var results []ModelStat
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
db := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
model,
|
||||
COUNT(*) as requests,
|
||||
@@ -896,7 +973,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
|
||||
AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
@@ -950,7 +1027,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
ActualCost float64 `gorm:"column:actual_cost"`
|
||||
}
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select(`
|
||||
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
|
||||
COUNT(*) as requests,
|
||||
@@ -1011,7 +1088,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
var avgDuration struct {
|
||||
AvgDurationMs float64 `gorm:"column:avg_duration_ms"`
|
||||
}
|
||||
r.db.WithContext(ctx).Model(&model.UsageLog{}).
|
||||
r.db.WithContext(ctx).Model(&usageLogModel{}).
|
||||
Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms").
|
||||
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
|
||||
Scan(&avgDuration)
|
||||
@@ -1090,3 +1167,137 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
Models: models,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type usageLogModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
UserID int64 `gorm:"index;not null"`
|
||||
ApiKeyID int64 `gorm:"index;not null"`
|
||||
AccountID int64 `gorm:"index;not null"`
|
||||
RequestID string `gorm:"size:64"`
|
||||
Model string `gorm:"size:100;index;not null"`
|
||||
|
||||
GroupID *int64 `gorm:"index"`
|
||||
SubscriptionID *int64 `gorm:"index"`
|
||||
|
||||
InputTokens int `gorm:"default:0;not null"`
|
||||
OutputTokens int `gorm:"default:0;not null"`
|
||||
CacheCreationTokens int `gorm:"default:0;not null"`
|
||||
CacheReadTokens int `gorm:"default:0;not null"`
|
||||
|
||||
CacheCreation5mTokens int `gorm:"default:0;not null"`
|
||||
CacheCreation1hTokens int `gorm:"default:0;not null"`
|
||||
|
||||
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
|
||||
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
|
||||
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
|
||||
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
|
||||
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
|
||||
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
|
||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null"`
|
||||
|
||||
BillingType int8 `gorm:"type:smallint;default:0;not null"`
|
||||
Stream bool `gorm:"default:false;not null"`
|
||||
DurationMs *int
|
||||
FirstTokenMs *int
|
||||
|
||||
CreatedAt time.Time `gorm:"index;not null"`
|
||||
|
||||
User *userModel `gorm:"foreignKey:UserID"`
|
||||
ApiKey *apiKeyModel `gorm:"foreignKey:ApiKeyID"`
|
||||
Account *accountModel `gorm:"foreignKey:AccountID"`
|
||||
Group *groupModel `gorm:"foreignKey:GroupID"`
|
||||
Subscription *userSubscriptionModel `gorm:"foreignKey:SubscriptionID"`
|
||||
}
|
||||
|
||||
func (usageLogModel) TableName() string { return "usage_logs" }
|
||||
|
||||
func usageLogModelToService(m *usageLogModel) *service.UsageLog {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.UsageLog{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
ApiKeyID: m.ApiKeyID,
|
||||
AccountID: m.AccountID,
|
||||
RequestID: m.RequestID,
|
||||
Model: m.Model,
|
||||
GroupID: m.GroupID,
|
||||
SubscriptionID: m.SubscriptionID,
|
||||
InputTokens: m.InputTokens,
|
||||
OutputTokens: m.OutputTokens,
|
||||
CacheCreationTokens: m.CacheCreationTokens,
|
||||
CacheReadTokens: m.CacheReadTokens,
|
||||
CacheCreation5mTokens: m.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: m.CacheCreation1hTokens,
|
||||
InputCost: m.InputCost,
|
||||
OutputCost: m.OutputCost,
|
||||
CacheCreationCost: m.CacheCreationCost,
|
||||
CacheReadCost: m.CacheReadCost,
|
||||
TotalCost: m.TotalCost,
|
||||
ActualCost: m.ActualCost,
|
||||
RateMultiplier: m.RateMultiplier,
|
||||
BillingType: m.BillingType,
|
||||
Stream: m.Stream,
|
||||
DurationMs: m.DurationMs,
|
||||
FirstTokenMs: m.FirstTokenMs,
|
||||
CreatedAt: m.CreatedAt,
|
||||
User: userModelToService(m.User),
|
||||
ApiKey: apiKeyModelToService(m.ApiKey),
|
||||
Account: accountModelToService(m.Account),
|
||||
Group: groupModelToService(m.Group),
|
||||
Subscription: userSubscriptionModelToService(m.Subscription),
|
||||
}
|
||||
}
|
||||
|
||||
func usageLogModelsToService(models []usageLogModel) []service.UsageLog {
|
||||
out := make([]service.UsageLog, 0, len(models))
|
||||
for i := range models {
|
||||
if s := usageLogModelToService(&models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func usageLogModelFromService(log *service.UsageLog) *usageLogModel {
|
||||
if log == nil {
|
||||
return nil
|
||||
}
|
||||
return &usageLogModel{
|
||||
ID: log.ID,
|
||||
UserID: log.UserID,
|
||||
ApiKeyID: log.ApiKeyID,
|
||||
AccountID: log.AccountID,
|
||||
RequestID: log.RequestID,
|
||||
Model: log.Model,
|
||||
GroupID: log.GroupID,
|
||||
SubscriptionID: log.SubscriptionID,
|
||||
InputTokens: log.InputTokens,
|
||||
OutputTokens: log.OutputTokens,
|
||||
CacheCreationTokens: log.CacheCreationTokens,
|
||||
CacheReadTokens: log.CacheReadTokens,
|
||||
CacheCreation5mTokens: log.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: log.CacheCreation1hTokens,
|
||||
InputCost: log.InputCost,
|
||||
OutputCost: log.OutputCost,
|
||||
CacheCreationCost: log.CacheCreationCost,
|
||||
CacheReadCost: log.CacheReadCost,
|
||||
TotalCost: log.TotalCost,
|
||||
ActualCost: log.ActualCost,
|
||||
RateMultiplier: log.RateMultiplier,
|
||||
BillingType: log.BillingType,
|
||||
Stream: log.Stream,
|
||||
DurationMs: log.DurationMs,
|
||||
FirstTokenMs: log.FirstTokenMs,
|
||||
CreatedAt: log.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func applyUsageLogModelToService(log *service.UsageLog, m *usageLogModel) {
|
||||
if log == nil || m == nil {
|
||||
return
|
||||
}
|
||||
log.ID = m.ID
|
||||
log.CreatedAt = m.CreatedAt
|
||||
}
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -32,8 +32,8 @@ func TestUsageLogRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UsageLogRepoSuite))
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKey, account *model.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *model.UsageLog {
|
||||
log := &model.UsageLog{
|
||||
func (s *UsageLogRepoSuite) createUsageLog(user *userModel, apiKey *apiKeyModel, account *accountModel, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
@@ -51,11 +51,11 @@ func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKe
|
||||
// --- Create / GetByID ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestCreate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-create"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-create", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-create"})
|
||||
|
||||
log := &model.UsageLog{
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
@@ -72,9 +72,9 @@ func (s *UsageLogRepoSuite) TestCreate() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-getbyid"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-getbyid"})
|
||||
|
||||
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -92,9 +92,9 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
|
||||
// --- Delete ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-delete"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-delete", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-delete"})
|
||||
|
||||
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -108,9 +108,9 @@ func (s *UsageLogRepoSuite) TestDelete() {
|
||||
// --- ListByUser ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByUser() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyuser"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyuser"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
|
||||
@@ -124,9 +124,9 @@ func (s *UsageLogRepoSuite) TestListByUser() {
|
||||
// --- ListByApiKey ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByApiKey() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyapikey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyapikey"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyapikey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyapikey"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
|
||||
@@ -140,9 +140,9 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
|
||||
// --- ListByAccount ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByAccount() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyaccount@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyaccount"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyaccount@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyaccount"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -155,9 +155,9 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
|
||||
// --- GetUserStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userstats"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "userstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -175,9 +175,9 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
|
||||
// --- ListWithFilters ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filters"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filters"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -194,29 +194,29 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
now := time.Now()
|
||||
todayStart := timezone.Today()
|
||||
|
||||
userToday := mustCreateUser(s.T(), s.db, &model.User{
|
||||
userToday := mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "today@example.com",
|
||||
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
|
||||
UpdatedAt: now,
|
||||
})
|
||||
userOld := mustCreateUser(s.T(), s.db, &model.User{
|
||||
userOld := mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "old@example.com",
|
||||
CreatedAt: todayStart.Add(-24 * time.Hour),
|
||||
UpdatedAt: todayStart.Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-ul"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
|
||||
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: model.StatusDisabled})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-ul"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
|
||||
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
|
||||
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
accNormal := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-normal", Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-error", Status: model.StatusError, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
|
||||
accNormal := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-normal", Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-error", Status: service.StatusError, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
|
||||
|
||||
d1, d2, d3 := 100, 200, 300
|
||||
logToday := &model.UsageLog{
|
||||
logToday := &service.UsageLog{
|
||||
UserID: userToday.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
@@ -233,7 +233,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
|
||||
|
||||
logOld := &model.UsageLog{
|
||||
logOld := &service.UsageLog{
|
||||
UserID: userOld.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
@@ -247,7 +247,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
|
||||
|
||||
logPerf := &model.UsageLog{
|
||||
logPerf := &service.UsageLog{
|
||||
UserID: userToday.ID,
|
||||
ApiKeyID: apiKey1.ID,
|
||||
AccountID: accNormal.ID,
|
||||
@@ -293,9 +293,9 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
// --- GetUserDashboardStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userdash@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userdash"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "userdash@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userdash", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userdash"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -308,9 +308,9 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||
// --- GetAccountTodayStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctoday@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-today"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctoday@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-today"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -323,11 +323,11 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||
// --- GetBatchUserUsageStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batch"})
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batch"})
|
||||
|
||||
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
@@ -348,10 +348,10 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
|
||||
// --- GetBatchApiKeyUsageStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batchkey@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batchkey"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "batchkey@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batchkey"})
|
||||
|
||||
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
@@ -370,9 +370,9 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||
// --- GetGlobalStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "global@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-global"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "global@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-global", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-global"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -395,9 +395,9 @@ func maxTime(a, b time.Time) time.Time {
|
||||
// --- ListByUserAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "timerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-timerange"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "timerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-timerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-timerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -414,9 +414,9 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
|
||||
// --- ListByApiKeyAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytimerange"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -433,9 +433,9 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||
// --- ListByAccountAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-acctimerange"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-acctimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -452,14 +452,14 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
|
||||
// --- ListByModelAndTimeRange ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modeltimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modeltimerange"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modeltimerange@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modeltimerange"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Create logs with different models
|
||||
log1 := &model.UsageLog{
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
@@ -472,7 +472,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
|
||||
log2 := &model.UsageLog{
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
@@ -485,7 +485,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||
|
||||
log3 := &model.UsageLog{
|
||||
log3 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
@@ -508,9 +508,9 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||
// --- GetAccountWindowStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "windowstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-windowstats"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "windowstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-windowstats"})
|
||||
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-10 * time.Minute)
|
||||
@@ -528,9 +528,9 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
|
||||
// --- GetUserUsageTrendByUserID ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrend"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrend"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -545,9 +545,9 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrendhourly@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrendhourly"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrendhourly@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrendhourly"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -564,14 +564,14 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
|
||||
// --- GetUserModelStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelstats"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// Create logs with different models
|
||||
log1 := &model.UsageLog{
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
@@ -584,7 +584,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
|
||||
log2 := &model.UsageLog{
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
@@ -611,9 +611,9 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||
// --- GetUsageTrendWithFilters ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -639,9 +639,9 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters-h@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters-h"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters-h@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters-h"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -658,13 +658,13 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
// --- GetModelStatsWithFilters ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelfilters"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelfilters@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelfilters"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
log1 := &model.UsageLog{
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
@@ -677,7 +677,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
|
||||
log2 := &model.UsageLog{
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
@@ -712,14 +712,14 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
// --- GetAccountUsageStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-accstats"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "accstats@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-accstats", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-accstats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
// Create logs on different days
|
||||
log1 := &model.UsageLog{
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
@@ -732,7 +732,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||
|
||||
log2 := &model.UsageLog{
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
@@ -758,7 +758,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-emptystats"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-emptystats"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||
startTime := base
|
||||
@@ -774,11 +774,11 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
|
||||
// --- GetUserUsageTrend ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrends"})
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrends"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base)
|
||||
@@ -796,10 +796,10 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
|
||||
// --- GetApiKeyUsageTrend ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrend@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrends"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrend@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrends"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
|
||||
@@ -815,9 +815,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrendh@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrendh"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrendh@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrendh"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
|
||||
@@ -834,9 +834,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||
// --- ListWithFilters (additional filter tests) ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterskey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterskey"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterskey@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterskey"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
|
||||
@@ -848,9 +848,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterstime@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterstime"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterstime@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterstime"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
@@ -867,9 +867,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterscombined@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterscombined"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterscombined@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterscombined"})
|
||||
|
||||
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||
|
||||
@@ -2,12 +2,13 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -19,48 +20,56 @@ func NewUserRepository(db *gorm.DB) service.UserRepository {
|
||||
return &userRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, user *model.User) error {
|
||||
err := r.db.WithContext(ctx).Create(user).Error
|
||||
func (r *userRepository) Create(ctx context.Context, user *service.User) error {
|
||||
m := userModelFromService(user)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
if err == nil {
|
||||
applyUserModelToService(user, m)
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.WithContext(ctx).First(&user, id).Error
|
||||
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
|
||||
var m userModel
|
||||
err := r.db.WithContext(ctx).First(&m, id).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return &user, nil
|
||||
return userModelToService(&m), nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
|
||||
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
|
||||
var m userModel
|
||||
err := r.db.WithContext(ctx).Where("email = ?", email).First(&m).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return &user, nil
|
||||
return userModelToService(&m), nil
|
||||
}
|
||||
|
||||
func (r *userRepository) Update(ctx context.Context, user *model.User) error {
|
||||
err := r.db.WithContext(ctx).Save(user).Error
|
||||
func (r *userRepository) Update(ctx context.Context, user *service.User) error {
|
||||
m := userModelFromService(user)
|
||||
err := r.db.WithContext(ctx).Save(m).Error
|
||||
if err == nil {
|
||||
applyUserModelToService(user, m)
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
func (r *userRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
|
||||
return r.db.WithContext(ctx).Delete(&userModel{}, id).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
|
||||
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists users with optional filtering by status, role, and search query
|
||||
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
|
||||
var users []model.User
|
||||
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
|
||||
var users []userModel
|
||||
var total int64
|
||||
|
||||
db := r.db.WithContext(ctx).Model(&model.User{})
|
||||
db := r.db.WithContext(ctx).Model(&userModel{})
|
||||
|
||||
// Apply filters
|
||||
if status != "" {
|
||||
@@ -89,17 +98,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
// Batch load subscriptions for all users (avoid N+1)
|
||||
if len(users) > 0 {
|
||||
userIDs := make([]int64, len(users))
|
||||
userMap := make(map[int64]*model.User, len(users))
|
||||
userMap := make(map[int64]*service.User, len(users))
|
||||
outUsers := make([]service.User, 0, len(users))
|
||||
for i := range users {
|
||||
userIDs[i] = users[i].ID
|
||||
userMap[users[i].ID] = &users[i]
|
||||
u := userModelToService(&users[i])
|
||||
outUsers = append(outUsers, *u)
|
||||
userMap[u.ID] = &outUsers[len(outUsers)-1]
|
||||
}
|
||||
|
||||
// Query active subscriptions with groups in one query
|
||||
var subscriptions []model.UserSubscription
|
||||
var subscriptions []userSubscriptionModel
|
||||
if err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id IN ? AND status = ?", userIDs, model.SubscriptionStatusActive).
|
||||
Where("user_id IN ? AND status = ?", userIDs, service.SubscriptionStatusActive).
|
||||
Find(&subscriptions).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -107,32 +119,29 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
// Associate subscriptions with users
|
||||
for i := range subscriptions {
|
||||
if user, ok := userMap[subscriptions[i].UserID]; ok {
|
||||
user.Subscriptions = append(user.Subscriptions, subscriptions[i])
|
||||
user.Subscriptions = append(user.Subscriptions, *userSubscriptionModelToService(&subscriptions[i]))
|
||||
}
|
||||
}
|
||||
|
||||
return outUsers, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
outUsers := make([]service.User, 0, len(users))
|
||||
for i := range users {
|
||||
outUsers = append(outUsers, *userModelToService(&users[i]))
|
||||
}
|
||||
|
||||
return users, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
return outUsers, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
|
||||
return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
|
||||
Update("balance", gorm.Expr("balance + ?", amount)).Error
|
||||
}
|
||||
|
||||
// DeductBalance 扣减用户余额,仅当余额充足时执行
|
||||
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
result := r.db.WithContext(ctx).Model(&model.User{}).
|
||||
result := r.db.WithContext(ctx).Model(&userModel{}).
|
||||
Where("id = ? AND balance >= ?", id, amount).
|
||||
Update("balance", gorm.Expr("balance - ?", amount))
|
||||
if result.Error != nil {
|
||||
@@ -145,34 +154,104 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||||
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
|
||||
return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
|
||||
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error
|
||||
err := r.db.WithContext(ctx).Model(&userModel{}).Where("email = ?", email).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
|
||||
// 使用 PostgreSQL 的 array_remove 函数
|
||||
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&model.User{}).
|
||||
result := r.db.WithContext(ctx).Model(&userModel{}).
|
||||
Where("? = ANY(allowed_groups)", groupID).
|
||||
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
|
||||
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
|
||||
var user model.User
|
||||
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
|
||||
var m userModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
|
||||
Where("role = ? AND status = ?", service.RoleAdmin, service.StatusActive).
|
||||
Order("id ASC").
|
||||
First(&user).Error
|
||||
First(&m).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
return &user, nil
|
||||
return userModelToService(&m), nil
|
||||
}
|
||||
|
||||
type userModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
Email string `gorm:"uniqueIndex;size:255;not null"`
|
||||
Username string `gorm:"size:100;default:''"`
|
||||
Wechat string `gorm:"size:100;default:''"`
|
||||
Notes string `gorm:"type:text;default:''"`
|
||||
PasswordHash string `gorm:"size:255;not null"`
|
||||
Role string `gorm:"size:20;default:user;not null"`
|
||||
Balance float64 `gorm:"type:decimal(20,8);default:0;not null"`
|
||||
Concurrency int `gorm:"default:5;not null"`
|
||||
Status string `gorm:"size:20;default:active;not null"`
|
||||
AllowedGroups pq.Int64Array `gorm:"type:bigint[]"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
func (userModel) TableName() string { return "users" }
|
||||
|
||||
func userModelToService(m *userModel) *service.User {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.User{
|
||||
ID: m.ID,
|
||||
Email: m.Email,
|
||||
Username: m.Username,
|
||||
Wechat: m.Wechat,
|
||||
Notes: m.Notes,
|
||||
PasswordHash: m.PasswordHash,
|
||||
Role: m.Role,
|
||||
Balance: m.Balance,
|
||||
Concurrency: m.Concurrency,
|
||||
Status: m.Status,
|
||||
AllowedGroups: []int64(m.AllowedGroups),
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func userModelFromService(u *service.User) *userModel {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &userModel{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Wechat: u.Wechat,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
AllowedGroups: pq.Int64Array(u.AllowedGroups),
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func applyUserModelToService(dst *service.User, src *userModel) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
dst.ID = src.ID
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
@@ -35,11 +34,12 @@ func TestUserRepoSuite(t *testing.T) {
|
||||
// --- Create / GetByID / GetByEmail / Update / Delete ---
|
||||
|
||||
func (s *UserRepoSuite) TestCreate() {
|
||||
user := &model.User{
|
||||
Email: "create@test.com",
|
||||
Username: "testuser",
|
||||
Role: model.RoleUser,
|
||||
Status: model.StatusActive,
|
||||
user := &service.User{
|
||||
Email: "create@test.com",
|
||||
Username: "testuser",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, user)
|
||||
@@ -57,7 +57,7 @@ func (s *UserRepoSuite) TestGetByID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetByEmail() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byemail@test.com"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "byemail@test.com"})
|
||||
|
||||
got, err := s.repo.GetByEmail(s.ctx, user.Email)
|
||||
s.Require().NoError(err, "GetByEmail")
|
||||
@@ -70,7 +70,7 @@ func (s *UserRepoSuite) TestGetByEmail_NotFound() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com", Username: "original"})
|
||||
user := userModelToService(mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com", Username: "original"}))
|
||||
|
||||
user.Username = "updated"
|
||||
err := s.repo.Update(s.ctx, user)
|
||||
@@ -82,7 +82,7 @@ func (s *UserRepoSuite) TestUpdate() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
@@ -94,8 +94,8 @@ func (s *UserRepoSuite) TestDelete() {
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *UserRepoSuite) TestList() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "list1@test.com"})
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "list2@test.com"})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "list1@test.com"})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "list2@test.com"})
|
||||
|
||||
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
@@ -104,28 +104,28 @@ func (s *UserRepoSuite) TestList() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Status() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com", Status: model.StatusActive})
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "disabled@test.com", Status: model.StatusDisabled})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com", Status: service.StatusActive})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "disabled@test.com", Status: service.StatusDisabled})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, "", "")
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Equal(model.StatusActive, users[0].Status)
|
||||
s.Require().Equal(service.StatusActive, users[0].Status)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Role() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "user@test.com", Role: model.RoleUser})
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "user@test.com", Role: service.RoleUser})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.RoleAdmin, "")
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Equal(model.RoleAdmin, users[0].Role)
|
||||
s.Require().Equal(service.RoleAdmin, users[0].Role)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Search() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "alice@test.com", Username: "Alice"})
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "bob@test.com", Username: "Bob"})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "alice@test.com", Username: "Alice"})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "bob@test.com", Username: "Bob"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
|
||||
s.Require().NoError(err)
|
||||
@@ -134,8 +134,8 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com", Username: "JohnDoe"})
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com", Username: "JaneSmith"})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com", Username: "JohnDoe"})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com", Username: "JaneSmith"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
|
||||
s.Require().NoError(err)
|
||||
@@ -144,8 +144,8 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "w1@test.com", Wechat: "wx_hello"})
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "w2@test.com", Wechat: "wx_world"})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "w1@test.com", Wechat: "wx_hello"})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "w2@test.com", Wechat: "wx_world"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
|
||||
s.Require().NoError(err)
|
||||
@@ -154,19 +154,19 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub@test.com", Status: model.StatusActive})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sub"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub@test.com", Status: service.StatusActive})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sub"})
|
||||
|
||||
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
_ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
})
|
||||
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
_ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusExpired,
|
||||
Status: service.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -179,29 +179,29 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "a@example.com",
|
||||
Username: "Alice",
|
||||
Wechat: "wx_a",
|
||||
Role: model.RoleUser,
|
||||
Status: model.StatusActive,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
})
|
||||
target := mustCreateUser(s.T(), s.db, &model.User{
|
||||
target := mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "b@example.com",
|
||||
Username: "Bob",
|
||||
Wechat: "wx_b",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusActive,
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
Balance: 1,
|
||||
})
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "c@example.com",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusDisabled,
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusDisabled,
|
||||
})
|
||||
|
||||
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, model.RoleAdmin, "b@")
|
||||
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, service.RoleAdmin, "b@")
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
||||
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
||||
@@ -211,7 +211,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
// --- Balance operations ---
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateBalance() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "bal@test.com", Balance: 10})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "bal@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
|
||||
s.Require().NoError(err, "UpdateBalance")
|
||||
@@ -222,7 +222,7 @@ func (s *UserRepoSuite) TestUpdateBalance() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "balneg@test.com", Balance: 10})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "balneg@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
|
||||
s.Require().NoError(err, "UpdateBalance with negative")
|
||||
@@ -233,7 +233,7 @@ func (s *UserRepoSuite) TestUpdateBalance_Negative() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deduct@test.com", Balance: 10})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "deduct@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
|
||||
s.Require().NoError(err, "DeductBalance")
|
||||
@@ -244,7 +244,7 @@ func (s *UserRepoSuite) TestDeductBalance() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "insuf@test.com", Balance: 5})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "insuf@test.com", Balance: 5})
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
|
||||
s.Require().Error(err, "expected error for insufficient balance")
|
||||
@@ -252,7 +252,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exact@test.com", Balance: 10})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exact@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err, "DeductBalance exact amount")
|
||||
@@ -265,7 +265,7 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
|
||||
// --- Concurrency ---
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateConcurrency() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "conc@test.com", Concurrency: 5})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "conc@test.com", Concurrency: 5})
|
||||
|
||||
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
|
||||
s.Require().NoError(err, "UpdateConcurrency")
|
||||
@@ -276,7 +276,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "concneg@test.com", Concurrency: 5})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "concneg@test.com", Concurrency: 5})
|
||||
|
||||
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
|
||||
s.Require().NoError(err, "UpdateConcurrency negative")
|
||||
@@ -289,7 +289,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
|
||||
// --- ExistsByEmail ---
|
||||
|
||||
func (s *UserRepoSuite) TestExistsByEmail() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
|
||||
mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
|
||||
|
||||
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
|
||||
s.Require().NoError(err, "ExistsByEmail")
|
||||
@@ -304,11 +304,11 @@ func (s *UserRepoSuite) TestExistsByEmail() {
|
||||
|
||||
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
|
||||
groupID := int64(42)
|
||||
userA := mustCreateUser(s.T(), s.db, &model.User{
|
||||
userA := mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "a1@example.com",
|
||||
AllowedGroups: pq.Int64Array{groupID, 7},
|
||||
})
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "a2@example.com",
|
||||
AllowedGroups: pq.Int64Array{7},
|
||||
})
|
||||
@@ -325,7 +325,7 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "nomatch@test.com",
|
||||
AllowedGroups: pq.Int64Array{1, 2, 3},
|
||||
})
|
||||
@@ -338,15 +338,15 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
|
||||
// --- GetFirstAdmin ---
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin() {
|
||||
admin1 := mustCreateUser(s.T(), s.db, &model.User{
|
||||
admin1 := mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "admin1@example.com",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusActive,
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "admin2@example.com",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusActive,
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetFirstAdmin(s.ctx)
|
||||
@@ -355,10 +355,10 @@ func (s *UserRepoSuite) TestGetFirstAdmin() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "user@example.com",
|
||||
Role: model.RoleUser,
|
||||
Status: model.StatusActive,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
_, err := s.repo.GetFirstAdmin(s.ctx)
|
||||
@@ -366,15 +366,15 @@ func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
|
||||
mustCreateUser(s.T(), s.db, &model.User{
|
||||
mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "disabled@example.com",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusDisabled,
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusDisabled,
|
||||
})
|
||||
activeAdmin := mustCreateUser(s.T(), s.db, &model.User{
|
||||
activeAdmin := mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "active@example.com",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusActive,
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetFirstAdmin(s.ctx)
|
||||
@@ -385,26 +385,26 @@ func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
|
||||
// --- Combined original test ---
|
||||
|
||||
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "a@example.com",
|
||||
Username: "Alice",
|
||||
Wechat: "wx_a",
|
||||
Role: model.RoleUser,
|
||||
Status: model.StatusActive,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "b@example.com",
|
||||
Username: "Bob",
|
||||
Wechat: "wx_b",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusActive,
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
Balance: 1,
|
||||
})
|
||||
_ = mustCreateUser(s.T(), s.db, &model.User{
|
||||
_ = mustCreateUser(s.T(), s.db, &userModel{
|
||||
Email: "c@example.com",
|
||||
Role: model.RoleAdmin,
|
||||
Status: model.StatusDisabled,
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusDisabled,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
@@ -441,7 +441,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
||||
s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch")
|
||||
|
||||
params := pagination.PaginationParams{Page: 1, PageSize: 10}
|
||||
users, page, err := s.repo.ListWithFilters(s.ctx, params, model.StatusActive, model.RoleAdmin, "b@")
|
||||
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@")
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
||||
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
||||
|
||||
@@ -4,111 +4,113 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserSubscriptionRepository 用户订阅仓库
|
||||
type userSubscriptionRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserSubscriptionRepository 创建用户订阅仓库
|
||||
func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
|
||||
return &userSubscriptionRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建订阅
|
||||
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
|
||||
err := r.db.WithContext(ctx).Create(sub).Error
|
||||
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||
m := userSubscriptionModelFromService(sub)
|
||||
err := r.db.WithContext(ctx).Create(m).Error
|
||||
if err == nil {
|
||||
applyUserSubscriptionModelToService(sub, m)
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取订阅
|
||||
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
|
||||
var sub model.UserSubscription
|
||||
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
|
||||
var m userSubscriptionModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("User").
|
||||
Preload("Group").
|
||||
Preload("AssignedByUser").
|
||||
First(&sub, id).Error
|
||||
First(&m, id).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
return &sub, nil
|
||||
return userSubscriptionModelToService(&m), nil
|
||||
}
|
||||
|
||||
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅
|
||||
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
||||
var sub model.UserSubscription
|
||||
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
var m userSubscriptionModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id = ? AND group_id = ?", userID, groupID).
|
||||
First(&sub).Error
|
||||
First(&m).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
return &sub, nil
|
||||
return userSubscriptionModelToService(&m), nil
|
||||
}
|
||||
|
||||
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅
|
||||
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
||||
var sub model.UserSubscription
|
||||
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
var m userSubscriptionModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?",
|
||||
userID, groupID, model.SubscriptionStatusActive, time.Now()).
|
||||
First(&sub).Error
|
||||
userID, groupID, service.SubscriptionStatusActive, time.Now()).
|
||||
First(&m).Error
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
return &sub, nil
|
||||
return userSubscriptionModelToService(&m), nil
|
||||
}
|
||||
|
||||
// Update 更新订阅
|
||||
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
|
||||
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
|
||||
sub.UpdatedAt = time.Now()
|
||||
return r.db.WithContext(ctx).Save(sub).Error
|
||||
m := userSubscriptionModelFromService(sub)
|
||||
err := r.db.WithContext(ctx).Save(m).Error
|
||||
if err == nil {
|
||||
applyUserSubscriptionModelToService(sub, m)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete 删除订阅
|
||||
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error
|
||||
return r.db.WithContext(ctx).Delete(&userSubscriptionModel{}, id).Error
|
||||
}
|
||||
|
||||
// ListByUserID 获取用户的所有订阅
|
||||
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
var subs []model.UserSubscription
|
||||
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
var subs []userSubscriptionModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Find(&subs).Error
|
||||
return subs, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return userSubscriptionModelsToService(subs), nil
|
||||
}
|
||||
|
||||
// ListActiveByUserID 获取用户的所有有效订阅
|
||||
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
var subs []model.UserSubscription
|
||||
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
var subs []userSubscriptionModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Group").
|
||||
Where("user_id = ? AND status = ? AND expires_at > ?",
|
||||
userID, model.SubscriptionStatusActive, time.Now()).
|
||||
userID, service.SubscriptionStatusActive, time.Now()).
|
||||
Order("created_at DESC").
|
||||
Find(&subs).Error
|
||||
return subs, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return userSubscriptionModelsToService(subs), nil
|
||||
}
|
||||
|
||||
// ListByGroupID 获取分组的所有订阅(分页)
|
||||
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
var subs []model.UserSubscription
|
||||
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
var subs []userSubscriptionModel
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&model.UserSubscription{}).Where("group_id = ?", groupID)
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).Where("group_id = ?", groupID)
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -124,26 +126,14 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return subs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
// List 获取所有订阅(分页,支持筛选)
|
||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
var subs []model.UserSubscription
|
||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
var subs []userSubscriptionModel
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&model.UserSubscription{})
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&userSubscriptionModel{})
|
||||
if userID != nil {
|
||||
query = query.Where("user_id = ?", *userID)
|
||||
}
|
||||
@@ -170,22 +160,87 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
return subs, &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}, nil
|
||||
return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("user_id = ? AND group_id = ?", userID, groupID).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", subscriptionID).
|
||||
Updates(map[string]any{
|
||||
"expires_at": newExpiresAt,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", subscriptionID).
|
||||
Updates(map[string]any{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", subscriptionID).
|
||||
Updates(map[string]any{
|
||||
"notes": notes,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"daily_window_start": start,
|
||||
"weekly_window_start": start,
|
||||
"monthly_window_start": start,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"daily_usage_usd": 0,
|
||||
"daily_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"weekly_usage_usd": 0,
|
||||
"weekly_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"monthly_usage_usd": 0,
|
||||
"monthly_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// IncrementUsage 增加使用量
|
||||
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
|
||||
@@ -195,131 +250,150 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ResetDailyUsage 重置日使用量
|
||||
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"daily_usage_usd": 0,
|
||||
"daily_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ResetWeeklyUsage 重置周使用量
|
||||
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"weekly_usage_usd": 0,
|
||||
"weekly_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ResetMonthlyUsage 重置月使用量
|
||||
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"monthly_usage_usd": 0,
|
||||
"monthly_window_start": newWindowStart,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ActivateWindows 激活所有窗口(首次使用时)
|
||||
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"daily_window_start": activateTime,
|
||||
"weekly_window_start": activateTime,
|
||||
"monthly_window_start": activateTime,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateStatus 更新订阅状态
|
||||
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ExtendExpiry 延长订阅过期时间
|
||||
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"expires_at": newExpiresAt,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateNotes 更新订阅备注
|
||||
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]any{
|
||||
"notes": notes,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ListExpired 获取所有已过期但状态仍为active的订阅
|
||||
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
|
||||
var subs []model.UserSubscription
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
||||
Find(&subs).Error
|
||||
return subs, err
|
||||
}
|
||||
|
||||
// BatchUpdateExpiredStatus 批量更新过期订阅状态
|
||||
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
|
||||
result := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
|
||||
Updates(map[string]any{
|
||||
"status": model.SubscriptionStatusExpired,
|
||||
"status": service.SubscriptionStatusExpired,
|
||||
"updated_at": time.Now(),
|
||||
})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅
|
||||
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
Where("user_id = ? AND group_id = ?", userID, groupID).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
// Extra repository helpers (currently used only by integration tests).
|
||||
|
||||
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
|
||||
var subs []userSubscriptionModel
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
|
||||
Find(&subs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return userSubscriptionModelsToService(subs), nil
|
||||
}
|
||||
|
||||
// CountByGroupID 获取分组的订阅数量
|
||||
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("group_id = ?", groupID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountActiveByGroupID 获取分组的有效订阅数量
|
||||
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
|
||||
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
|
||||
Where("group_id = ? AND status = ? AND expires_at > ?",
|
||||
groupID, model.SubscriptionStatusActive, time.Now()).
|
||||
groupID, service.SubscriptionStatusActive, time.Now()).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除分组相关的所有订阅记录
|
||||
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{})
|
||||
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&userSubscriptionModel{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
type userSubscriptionModel struct {
|
||||
ID int64 `gorm:"primaryKey"`
|
||||
UserID int64 `gorm:"index;not null"`
|
||||
GroupID int64 `gorm:"index;not null"`
|
||||
|
||||
StartsAt time.Time `gorm:"not null"`
|
||||
ExpiresAt time.Time `gorm:"not null"`
|
||||
Status string `gorm:"size:20;default:active;not null"`
|
||||
|
||||
DailyWindowStart *time.Time
|
||||
WeeklyWindowStart *time.Time
|
||||
MonthlyWindowStart *time.Time
|
||||
|
||||
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
|
||||
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
|
||||
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
|
||||
|
||||
AssignedBy *int64 `gorm:"index"`
|
||||
AssignedAt time.Time `gorm:"not null"`
|
||||
Notes string `gorm:"type:text"`
|
||||
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
|
||||
User *userModel `gorm:"foreignKey:UserID"`
|
||||
Group *groupModel `gorm:"foreignKey:GroupID"`
|
||||
AssignedByUser *userModel `gorm:"foreignKey:AssignedBy"`
|
||||
}
|
||||
|
||||
func (userSubscriptionModel) TableName() string { return "user_subscriptions" }
|
||||
|
||||
func userSubscriptionModelToService(m *userSubscriptionModel) *service.UserSubscription {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.UserSubscription{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
GroupID: m.GroupID,
|
||||
StartsAt: m.StartsAt,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
Status: m.Status,
|
||||
DailyWindowStart: m.DailyWindowStart,
|
||||
WeeklyWindowStart: m.WeeklyWindowStart,
|
||||
MonthlyWindowStart: m.MonthlyWindowStart,
|
||||
DailyUsageUSD: m.DailyUsageUSD,
|
||||
WeeklyUsageUSD: m.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: m.MonthlyUsageUSD,
|
||||
AssignedBy: m.AssignedBy,
|
||||
AssignedAt: m.AssignedAt,
|
||||
Notes: m.Notes,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
User: userModelToService(m.User),
|
||||
Group: groupModelToService(m.Group),
|
||||
AssignedByUser: userModelToService(m.AssignedByUser),
|
||||
}
|
||||
}
|
||||
|
||||
func userSubscriptionModelsToService(models []userSubscriptionModel) []service.UserSubscription {
|
||||
out := make([]service.UserSubscription, 0, len(models))
|
||||
for i := range models {
|
||||
if s := userSubscriptionModelToService(&models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func userSubscriptionModelFromService(s *service.UserSubscription) *userSubscriptionModel {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &userSubscriptionModel{
|
||||
ID: s.ID,
|
||||
UserID: s.UserID,
|
||||
GroupID: s.GroupID,
|
||||
StartsAt: s.StartsAt,
|
||||
ExpiresAt: s.ExpiresAt,
|
||||
Status: s.Status,
|
||||
DailyWindowStart: s.DailyWindowStart,
|
||||
WeeklyWindowStart: s.WeeklyWindowStart,
|
||||
MonthlyWindowStart: s.MonthlyWindowStart,
|
||||
DailyUsageUSD: s.DailyUsageUSD,
|
||||
WeeklyUsageUSD: s.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: s.MonthlyUsageUSD,
|
||||
AssignedBy: s.AssignedBy,
|
||||
AssignedAt: s.AssignedAt,
|
||||
Notes: s.Notes,
|
||||
CreatedAt: s.CreatedAt,
|
||||
UpdatedAt: s.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func applyUserSubscriptionModelToService(sub *service.UserSubscription, m *userSubscriptionModel) {
|
||||
if sub == nil || m == nil {
|
||||
return
|
||||
}
|
||||
sub.ID = m.ID
|
||||
sub.CreatedAt = m.CreatedAt
|
||||
sub.UpdatedAt = m.UpdatedAt
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -33,13 +33,13 @@ func TestUserSubscriptionRepoSuite(t *testing.T) {
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestCreate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub-create@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-create"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub-create@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-create"})
|
||||
|
||||
sub := &model.UserSubscription{
|
||||
sub := &service.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
@@ -54,14 +54,14 @@ func (s *UserSubscriptionRepoSuite) TestCreate() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "preload@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
|
||||
admin := mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "preload@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
|
||||
admin := mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
|
||||
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
AssignedBy: &admin.ID,
|
||||
})
|
||||
@@ -82,14 +82,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestUpdate() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-update"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-update"})
|
||||
sub := userSubscriptionModelToService(mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
}))
|
||||
|
||||
sub.Notes = "updated notes"
|
||||
err := s.repo.Update(s.ctx, sub)
|
||||
@@ -101,12 +101,12 @@ func (s *UserSubscriptionRepoSuite) TestUpdate() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestDelete() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delete"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delete"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -120,12 +120,12 @@ func (s *UserSubscriptionRepoSuite) TestDelete() {
|
||||
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byuser@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-byuser"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "byuser@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-byuser"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -141,14 +141,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-active"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-active"})
|
||||
|
||||
// Create active subscription (future expiry)
|
||||
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(2 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -158,14 +158,14 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "expired@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-expired"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "expired@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-expired"})
|
||||
|
||||
// Create expired subscription (past expiry but active status)
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-2 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -176,20 +176,20 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnor
|
||||
// --- ListByUserID / ListActiveByUserID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
|
||||
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list2"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
|
||||
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list2"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: g1.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: g2.ID,
|
||||
Status: model.SubscriptionStatusExpired,
|
||||
Status: service.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -202,46 +202,46 @@ func (s *UserSubscriptionRepoSuite) TestListByUserID() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listactive@test.com"})
|
||||
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act2"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listactive@test.com"})
|
||||
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act2"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: g1.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: g2.ID,
|
||||
Status: model.SubscriptionStatusExpired,
|
||||
Status: service.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "ListActiveByUserID")
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(model.SubscriptionStatusActive, subs[0].Status)
|
||||
s.Require().Equal(service.SubscriptionStatusActive, subs[0].Status)
|
||||
}
|
||||
|
||||
// --- ListByGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listgrp"})
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listgrp"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user1.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user2.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -258,13 +258,13 @@ func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
|
||||
// --- List with filters ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "list@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "list@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -275,20 +275,20 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-filter"})
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-filter"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user1.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user2.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -299,20 +299,20 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grpfilter@test.com"})
|
||||
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f2"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "grpfilter@test.com"})
|
||||
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f1"})
|
||||
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f2"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: g1.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: g2.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -323,37 +323,37 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "statfilter@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-stat"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "statfilter@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-stat"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusExpired,
|
||||
Status: service.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, model.SubscriptionStatusExpired)
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(model.SubscriptionStatusExpired, subs[0].Status)
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
|
||||
}
|
||||
|
||||
// --- Usage tracking ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usage@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-usage"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usage@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-usage"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -368,12 +368,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accum@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-accum"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "accum@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-accum"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -386,12 +386,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "activate@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-activate"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "activate@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-activate"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -408,12 +408,12 @@ func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetd@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetd"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetd@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetd"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
DailyUsageUSD: 10.0,
|
||||
WeeklyUsageUSD: 20.0,
|
||||
@@ -431,12 +431,12 @@ func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetw@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetw"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetw@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetw"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
WeeklyUsageUSD: 15.0,
|
||||
MonthlyUsageUSD: 30.0,
|
||||
@@ -454,12 +454,12 @@ func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetm@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetm"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetm@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetm"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
MonthlyUsageUSD: 100.0,
|
||||
})
|
||||
@@ -477,30 +477,30 @@ func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
|
||||
// --- UpdateStatus / ExtendExpiry / UpdateNotes ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "status@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-status"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "status@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-status"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
err := s.repo.UpdateStatus(s.ctx, sub.ID, model.SubscriptionStatusExpired)
|
||||
err := s.repo.UpdateStatus(s.ctx, sub.ID, service.SubscriptionStatusExpired)
|
||||
s.Require().NoError(err, "UpdateStatus")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(model.SubscriptionStatusExpired, got.Status)
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, got.Status)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "extend@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-extend"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "extend@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-extend"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -514,12 +514,12 @@ func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "notes@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-notes"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "notes@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-notes"})
|
||||
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -534,19 +534,19 @@ func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
|
||||
// --- ListExpired / BatchUpdateExpiredStatus ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListExpired() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listexp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listexp"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listexp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listexp"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -556,19 +556,19 @@ func (s *UserSubscriptionRepoSuite) TestListExpired() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batch@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-batch"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "batch@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-batch"})
|
||||
|
||||
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -577,22 +577,22 @@ func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
|
||||
s.Require().Equal(int64(1), affected)
|
||||
|
||||
gotActive, _ := s.repo.GetByID(s.ctx, active.ID)
|
||||
s.Require().Equal(model.SubscriptionStatusActive, gotActive.Status)
|
||||
s.Require().Equal(service.SubscriptionStatusActive, gotActive.Status)
|
||||
|
||||
gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID)
|
||||
s.Require().Equal(model.SubscriptionStatusExpired, gotExpired.Status)
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, gotExpired.Status)
|
||||
}
|
||||
|
||||
// --- ExistsByUserIDAndGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-exists"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-exists"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -608,20 +608,20 @@ func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
|
||||
// --- CountByGroupID / CountActiveByGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user1.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user2.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusExpired,
|
||||
Status: service.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -631,20 +631,20 @@ func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
|
||||
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-cntact"})
|
||||
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact2@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-cntact"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user1.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user2.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time
|
||||
})
|
||||
|
||||
@@ -656,19 +656,19 @@ func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
|
||||
// --- DeleteByGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delgrp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delgrp"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delgrp@test.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delgrp"})
|
||||
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
})
|
||||
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusExpired,
|
||||
Status: service.SubscriptionStatusExpired,
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -683,19 +683,19 @@ func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
|
||||
// --- Combined original test ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
|
||||
user := mustCreateUser(s.T(), s.db, &model.User{Email: "subr@example.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-subr"})
|
||||
user := mustCreateUser(s.T(), s.db, &userModel{Email: "subr@example.com"})
|
||||
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-subr"})
|
||||
|
||||
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(2 * time.Hour),
|
||||
})
|
||||
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||
expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(-2 * time.Hour),
|
||||
})
|
||||
|
||||
@@ -729,5 +729,5 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
|
||||
s.Require().NoError(err, "GetByID expired")
|
||||
s.Require().Equal(model.SubscriptionStatusExpired, updated.Status, "expected status expired")
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewIdentityCache,
|
||||
NewRedeemCache,
|
||||
NewUpdateCache,
|
||||
NewGeminiTokenCache,
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
@@ -35,4 +36,6 @@ var ProviderSet = wire.NewSet(
|
||||
NewClaudeOAuthClient,
|
||||
NewHTTPUpstream,
|
||||
NewOpenAIOAuthClient,
|
||||
NewGeminiOAuthClient,
|
||||
NewGeminiCliCodeAssistClient,
|
||||
)
|
||||
|
||||
1140
backend/internal/server/api_contract_test.go
Normal file
1140
backend/internal/server/api_contract_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
@@ -25,6 +26,8 @@ func ProvideRouter(
|
||||
jwtAuth middleware2.JWTAuthMiddleware,
|
||||
adminAuth middleware2.AdminAuthMiddleware,
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -33,7 +36,7 @@ func ProvideRouter(
|
||||
r := gin.New()
|
||||
r.Use(middleware2.Recovery())
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth)
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService)
|
||||
}
|
||||
|
||||
// ProvideHTTPServer 提供 HTTP 服务器
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -84,7 +83,11 @@ func validateAdminApiKey(
|
||||
return false
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), admin)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: admin.ID,
|
||||
Concurrency: admin.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), admin.Role)
|
||||
c.Set("auth_method", "admin_api_key")
|
||||
return true
|
||||
}
|
||||
@@ -121,12 +124,16 @@ func validateJWTForAdmin(
|
||||
}
|
||||
|
||||
// 检查管理员权限
|
||||
if user.Role != model.RoleAdmin {
|
||||
if !user.IsAdmin() {
|
||||
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
|
||||
return false
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), user)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: user.ID,
|
||||
Concurrency: user.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), user.Role)
|
||||
c.Set("auth_method", "jwt")
|
||||
|
||||
return true
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user