mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-05 07:52:13 +08:00
Compare commits
91 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c3d1891ccd | ||
|
|
4d8f2db924 | ||
|
|
6599b366dc | ||
|
|
ba16ace697 | ||
|
|
cadca752c4 | ||
|
|
edf215e6fd | ||
|
|
e12dd079fd | ||
|
|
04a509d45e | ||
|
|
269a659200 | ||
|
|
2c31bf46b5 | ||
|
|
8f6639f825 | ||
|
|
fc17d9d7df | ||
|
|
ab092e88a8 | ||
|
|
56a1e29cdd | ||
|
|
0059a232a6 | ||
|
|
45676fdc8d | ||
|
|
e32c5f534f | ||
|
|
426d691c95 | ||
|
|
e9a4c8ab97 | ||
|
|
a55cfebd09 | ||
|
|
34cc02f8c7 | ||
|
|
624d9fddb7 | ||
|
|
47fbe43324 | ||
|
|
1245f07a2d | ||
|
|
839975b0cf | ||
|
|
8c1233393f | ||
|
|
9cdb0568cc | ||
|
|
74e05b83ea | ||
|
|
4ded9e7d49 | ||
|
|
716272a1e2 | ||
|
|
9cc8352593 | ||
|
|
43a1031e38 | ||
|
|
a5547b2f30 | ||
|
|
b0aa23540b | ||
|
|
ffaa6c4a17 | ||
|
|
fbf72f0ec4 | ||
|
|
909b8a8f9c | ||
|
|
4a0fe3b143 | ||
|
|
a1292fac81 | ||
|
|
7f98be4f91 | ||
|
|
fd73b8875d | ||
|
|
f9ab1daa3c | ||
|
|
d27b847442 | ||
|
|
dac6bc2228 | ||
|
|
4bd3dbf2ce | ||
|
|
226df1c23a | ||
|
|
2665230a09 | ||
|
|
4f0c2b794c | ||
|
|
e756064c19 | ||
|
|
17dfb0af01 | ||
|
|
ff74f517df | ||
|
|
477a9a180f | ||
|
|
da48df06d2 | ||
|
|
39fad63ccf | ||
|
|
5602d02b1b | ||
|
|
81989eed1c | ||
|
|
192efb84a0 | ||
|
|
8672347f93 | ||
|
|
5e5d4a513b | ||
|
|
88b6358472 | ||
|
|
dd8d5e2c42 | ||
|
|
d91e2328fb | ||
|
|
2a16735495 | ||
|
|
292f25f9ca | ||
|
|
c92e37775a | ||
|
|
f6ed3d1456 | ||
|
|
84686753e8 | ||
|
|
91f01309da | ||
|
|
57a1fc9d33 | ||
|
|
c95a864975 | ||
|
|
7a83db6180 | ||
|
|
a8513da7ff | ||
|
|
53534d3956 | ||
|
|
cc07a0e295 | ||
|
|
e7bc62500b | ||
|
|
c8fb9ef3a5 | ||
|
|
eb5e6214bc | ||
|
|
568d6ee10e | ||
|
|
6aef1af76e | ||
|
|
a54852e129 | ||
|
|
668118def1 | ||
|
|
73e6b160f8 | ||
|
|
6fec141de6 | ||
|
|
31cde6c555 | ||
|
|
b1a980f344 | ||
|
|
00d9fbd220 | ||
|
|
4f4c9679bf | ||
|
|
3dab71729d | ||
|
|
2f6f758670 | ||
|
|
090c8981dd | ||
|
|
fbb572948d |
4
.github/workflows/backend-ci.yml
vendored
4
.github/workflows/backend-ci.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
cache: true
|
cache: true
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.25.5'
|
go version | grep -q 'go1.25.6'
|
||||||
- name: Unit tests
|
- name: Unit tests
|
||||||
working-directory: backend
|
working-directory: backend
|
||||||
run: make test-unit
|
run: make test-unit
|
||||||
@@ -38,7 +38,7 @@ jobs:
|
|||||||
cache: true
|
cache: true
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.25.5'
|
go version | grep -q 'go1.25.6'
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v9
|
uses: golangci/golangci-lint-action@v9
|
||||||
with:
|
with:
|
||||||
|
|||||||
5
.github/workflows/release.yml
vendored
5
.github/workflows/release.yml
vendored
@@ -115,7 +115,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.25.5'
|
go version | grep -q 'go1.25.6'
|
||||||
|
|
||||||
# Docker setup for GoReleaser
|
# Docker setup for GoReleaser
|
||||||
- name: Set up QEMU
|
- name: Set up QEMU
|
||||||
@@ -222,8 +222,9 @@ jobs:
|
|||||||
REPO="${{ github.repository }}"
|
REPO="${{ github.repository }}"
|
||||||
GHCR_IMAGE="ghcr.io/${REPO,,}" # ${,,} converts to lowercase
|
GHCR_IMAGE="ghcr.io/${REPO,,}" # ${,,} converts to lowercase
|
||||||
|
|
||||||
# 获取 tag message 内容
|
# 获取 tag message 内容并转义 Markdown 特殊字符
|
||||||
TAG_MESSAGE='${{ steps.tag_message.outputs.message }}'
|
TAG_MESSAGE='${{ steps.tag_message.outputs.message }}'
|
||||||
|
TAG_MESSAGE=$(echo "$TAG_MESSAGE" | sed 's/\([_*`\[]\)/\\\1/g')
|
||||||
|
|
||||||
# 限制消息长度(Telegram 消息限制 4096 字符,预留空间给头尾固定内容)
|
# 限制消息长度(Telegram 消息限制 4096 字符,预留空间给头尾固定内容)
|
||||||
if [ ${#TAG_MESSAGE} -gt 3500 ]; then
|
if [ ${#TAG_MESSAGE} -gt 3500 ]; then
|
||||||
|
|||||||
2
.github/workflows/security-scan.yml
vendored
2
.github/workflows/security-scan.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
|||||||
cache-dependency-path: backend/go.sum
|
cache-dependency-path: backend/go.sum
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.25.5'
|
go version | grep -q 'go1.25.6'
|
||||||
- name: Run govulncheck
|
- name: Run govulncheck
|
||||||
working-directory: backend
|
working-directory: backend
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
ARG NODE_IMAGE=node:24-alpine
|
ARG NODE_IMAGE=node:24-alpine
|
||||||
ARG GOLANG_IMAGE=golang:1.25.5-alpine
|
ARG GOLANG_IMAGE=golang:1.25.6-alpine
|
||||||
ARG ALPINE_IMAGE=alpine:3.20
|
ARG ALPINE_IMAGE=alpine:3.20
|
||||||
ARG GOPROXY=https://goproxy.cn,direct
|
ARG GOPROXY=https://goproxy.cn,direct
|
||||||
ARG GOSUMDB=sum.golang.google.cn
|
ARG GOSUMDB=sum.golang.google.cn
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ English | [中文](README_CN.md)
|
|||||||
|
|
||||||
## Demo
|
## Demo
|
||||||
|
|
||||||
Try Sub2API online: **https://v2.pincc.ai/**
|
Try Sub2API online: **https://demo.sub2api.org/**
|
||||||
|
|
||||||
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
0.1.46
|
0.1.61
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ func provideCleanup(
|
|||||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||||
tokenRefresh *service.TokenRefreshService,
|
tokenRefresh *service.TokenRefreshService,
|
||||||
accountExpiry *service.AccountExpiryService,
|
accountExpiry *service.AccountExpiryService,
|
||||||
|
subscriptionExpiry *service.SubscriptionExpiryService,
|
||||||
usageCleanup *service.UsageCleanupService,
|
usageCleanup *service.UsageCleanupService,
|
||||||
pricing *service.PricingService,
|
pricing *service.PricingService,
|
||||||
emailQueue *service.EmailQueueService,
|
emailQueue *service.EmailQueueService,
|
||||||
@@ -138,6 +139,10 @@ func provideCleanup(
|
|||||||
accountExpiry.Stop()
|
accountExpiry.Stop()
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"SubscriptionExpiryService", func() error {
|
||||||
|
subscriptionExpiry.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"PricingService", func() error {
|
{"PricingService", func() error {
|
||||||
pricing.Stop()
|
pricing.Stop()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -63,7 +63,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
|
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
|
||||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
|
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
totpCache := repository.NewTotpCache(redisClient)
|
||||||
|
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
|
||||||
|
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService)
|
||||||
userHandler := handler.NewUserHandler(userService)
|
userHandler := handler.NewUserHandler(userService)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||||
@@ -84,7 +90,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
}
|
}
|
||||||
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
|
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
|
||||||
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
||||||
accountRepository := repository.NewAccountRepository(client, db)
|
schedulerCache := repository.NewSchedulerCache(redisClient)
|
||||||
|
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
|
||||||
proxyRepository := repository.NewProxyRepository(client, db)
|
proxyRepository := repository.NewProxyRepository(client, db)
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||||
@@ -105,21 +112,22 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
|
||||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||||
|
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
|
||||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
||||||
usageCache := service.NewUsageCache()
|
usageCache := service.NewUsageCache()
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
|
identityCache := repository.NewIdentityCache(redisClient)
|
||||||
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
|
||||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
|
||||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache)
|
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
|
||||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||||
@@ -128,7 +136,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||||
promoHandler := admin.NewPromoHandler(promoService)
|
promoHandler := admin.NewPromoHandler(promoService)
|
||||||
opsRepository := repository.NewOpsRepository(db)
|
opsRepository := repository.NewOpsRepository(db)
|
||||||
schedulerCache := repository.NewSchedulerCache(redisClient)
|
|
||||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||||
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
||||||
@@ -137,7 +144,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
billingService := service.NewBillingService(configConfig, pricingService)
|
billingService := service.NewBillingService(configConfig, pricingService)
|
||||||
identityCache := repository.NewIdentityCache(redisClient)
|
|
||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||||
@@ -165,7 +171,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
|
||||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
|
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
|
||||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
totpHandler := handler.NewTotpHandler(totpService)
|
||||||
|
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
|
||||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||||
@@ -178,7 +185,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
Cleanup: v,
|
Cleanup: v,
|
||||||
@@ -211,6 +219,7 @@ func provideCleanup(
|
|||||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||||
tokenRefresh *service.TokenRefreshService,
|
tokenRefresh *service.TokenRefreshService,
|
||||||
accountExpiry *service.AccountExpiryService,
|
accountExpiry *service.AccountExpiryService,
|
||||||
|
subscriptionExpiry *service.SubscriptionExpiryService,
|
||||||
usageCleanup *service.UsageCleanupService,
|
usageCleanup *service.UsageCleanupService,
|
||||||
pricing *service.PricingService,
|
pricing *service.PricingService,
|
||||||
emailQueue *service.EmailQueueService,
|
emailQueue *service.EmailQueueService,
|
||||||
@@ -278,6 +287,10 @@ func provideCleanup(
|
|||||||
accountExpiry.Stop()
|
accountExpiry.Stop()
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"SubscriptionExpiryService", func() error {
|
||||||
|
subscriptionExpiry.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"PricingService", func() error {
|
{"PricingService", func() error {
|
||||||
pricing.Stop()
|
pricing.Stop()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -610,6 +610,9 @@ var (
|
|||||||
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
|
||||||
{Name: "username", Type: field.TypeString, Size: 100, Default: ""},
|
{Name: "username", Type: field.TypeString, Size: 100, Default: ""},
|
||||||
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
|
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
|
||||||
|
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||||
|
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
|
||||||
|
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
|
||||||
}
|
}
|
||||||
// UsersTable holds the schema information for the "users" table.
|
// UsersTable holds the schema information for the "users" table.
|
||||||
UsersTable = &schema.Table{
|
UsersTable = &schema.Table{
|
||||||
|
|||||||
@@ -14360,6 +14360,9 @@ type UserMutation struct {
|
|||||||
status *string
|
status *string
|
||||||
username *string
|
username *string
|
||||||
notes *string
|
notes *string
|
||||||
|
totp_secret_encrypted *string
|
||||||
|
totp_enabled *bool
|
||||||
|
totp_enabled_at *time.Time
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
api_keys map[int64]struct{}
|
api_keys map[int64]struct{}
|
||||||
removedapi_keys map[int64]struct{}
|
removedapi_keys map[int64]struct{}
|
||||||
@@ -14937,6 +14940,140 @@ func (m *UserMutation) ResetNotes() {
|
|||||||
m.notes = nil
|
m.notes = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||||
|
func (m *UserMutation) SetTotpSecretEncrypted(s string) {
|
||||||
|
m.totp_secret_encrypted = &s
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncrypted returns the value of the "totp_secret_encrypted" field in the mutation.
|
||||||
|
func (m *UserMutation) TotpSecretEncrypted() (r string, exists bool) {
|
||||||
|
v := m.totp_secret_encrypted
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldTotpSecretEncrypted returns the old "totp_secret_encrypted" field's value of the User entity.
|
||||||
|
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UserMutation) OldTotpSecretEncrypted(ctx context.Context) (v *string, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldTotpSecretEncrypted is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldTotpSecretEncrypted requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldTotpSecretEncrypted: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.TotpSecretEncrypted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||||
|
func (m *UserMutation) ClearTotpSecretEncrypted() {
|
||||||
|
m.totp_secret_encrypted = nil
|
||||||
|
m.clearedFields[user.FieldTotpSecretEncrypted] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedCleared returns if the "totp_secret_encrypted" field was cleared in this mutation.
|
||||||
|
func (m *UserMutation) TotpSecretEncryptedCleared() bool {
|
||||||
|
_, ok := m.clearedFields[user.FieldTotpSecretEncrypted]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetTotpSecretEncrypted resets all changes to the "totp_secret_encrypted" field.
|
||||||
|
func (m *UserMutation) ResetTotpSecretEncrypted() {
|
||||||
|
m.totp_secret_encrypted = nil
|
||||||
|
delete(m.clearedFields, user.FieldTotpSecretEncrypted)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabled sets the "totp_enabled" field.
|
||||||
|
func (m *UserMutation) SetTotpEnabled(b bool) {
|
||||||
|
m.totp_enabled = &b
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabled returns the value of the "totp_enabled" field in the mutation.
|
||||||
|
func (m *UserMutation) TotpEnabled() (r bool, exists bool) {
|
||||||
|
v := m.totp_enabled
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldTotpEnabled returns the old "totp_enabled" field's value of the User entity.
|
||||||
|
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UserMutation) OldTotpEnabled(ctx context.Context) (v bool, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldTotpEnabled is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldTotpEnabled requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldTotpEnabled: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.TotpEnabled, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetTotpEnabled resets all changes to the "totp_enabled" field.
|
||||||
|
func (m *UserMutation) ResetTotpEnabled() {
|
||||||
|
m.totp_enabled = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||||
|
func (m *UserMutation) SetTotpEnabledAt(t time.Time) {
|
||||||
|
m.totp_enabled_at = &t
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledAt returns the value of the "totp_enabled_at" field in the mutation.
|
||||||
|
func (m *UserMutation) TotpEnabledAt() (r time.Time, exists bool) {
|
||||||
|
v := m.totp_enabled_at
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldTotpEnabledAt returns the old "totp_enabled_at" field's value of the User entity.
|
||||||
|
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UserMutation) OldTotpEnabledAt(ctx context.Context) (v *time.Time, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldTotpEnabledAt is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldTotpEnabledAt requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldTotpEnabledAt: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.TotpEnabledAt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||||
|
func (m *UserMutation) ClearTotpEnabledAt() {
|
||||||
|
m.totp_enabled_at = nil
|
||||||
|
m.clearedFields[user.FieldTotpEnabledAt] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledAtCleared returns if the "totp_enabled_at" field was cleared in this mutation.
|
||||||
|
func (m *UserMutation) TotpEnabledAtCleared() bool {
|
||||||
|
_, ok := m.clearedFields[user.FieldTotpEnabledAt]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetTotpEnabledAt resets all changes to the "totp_enabled_at" field.
|
||||||
|
func (m *UserMutation) ResetTotpEnabledAt() {
|
||||||
|
m.totp_enabled_at = nil
|
||||||
|
delete(m.clearedFields, user.FieldTotpEnabledAt)
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||||
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
|
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
|
||||||
if m.api_keys == nil {
|
if m.api_keys == nil {
|
||||||
@@ -15403,7 +15540,7 @@ func (m *UserMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *UserMutation) Fields() []string {
|
func (m *UserMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 11)
|
fields := make([]string, 0, 14)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, user.FieldCreatedAt)
|
fields = append(fields, user.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -15437,6 +15574,15 @@ func (m *UserMutation) Fields() []string {
|
|||||||
if m.notes != nil {
|
if m.notes != nil {
|
||||||
fields = append(fields, user.FieldNotes)
|
fields = append(fields, user.FieldNotes)
|
||||||
}
|
}
|
||||||
|
if m.totp_secret_encrypted != nil {
|
||||||
|
fields = append(fields, user.FieldTotpSecretEncrypted)
|
||||||
|
}
|
||||||
|
if m.totp_enabled != nil {
|
||||||
|
fields = append(fields, user.FieldTotpEnabled)
|
||||||
|
}
|
||||||
|
if m.totp_enabled_at != nil {
|
||||||
|
fields = append(fields, user.FieldTotpEnabledAt)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -15467,6 +15613,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.Username()
|
return m.Username()
|
||||||
case user.FieldNotes:
|
case user.FieldNotes:
|
||||||
return m.Notes()
|
return m.Notes()
|
||||||
|
case user.FieldTotpSecretEncrypted:
|
||||||
|
return m.TotpSecretEncrypted()
|
||||||
|
case user.FieldTotpEnabled:
|
||||||
|
return m.TotpEnabled()
|
||||||
|
case user.FieldTotpEnabledAt:
|
||||||
|
return m.TotpEnabledAt()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -15498,6 +15650,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
|
|||||||
return m.OldUsername(ctx)
|
return m.OldUsername(ctx)
|
||||||
case user.FieldNotes:
|
case user.FieldNotes:
|
||||||
return m.OldNotes(ctx)
|
return m.OldNotes(ctx)
|
||||||
|
case user.FieldTotpSecretEncrypted:
|
||||||
|
return m.OldTotpSecretEncrypted(ctx)
|
||||||
|
case user.FieldTotpEnabled:
|
||||||
|
return m.OldTotpEnabled(ctx)
|
||||||
|
case user.FieldTotpEnabledAt:
|
||||||
|
return m.OldTotpEnabledAt(ctx)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown User field %s", name)
|
return nil, fmt.Errorf("unknown User field %s", name)
|
||||||
}
|
}
|
||||||
@@ -15584,6 +15742,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetNotes(v)
|
m.SetNotes(v)
|
||||||
return nil
|
return nil
|
||||||
|
case user.FieldTotpSecretEncrypted:
|
||||||
|
v, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetTotpSecretEncrypted(v)
|
||||||
|
return nil
|
||||||
|
case user.FieldTotpEnabled:
|
||||||
|
v, ok := value.(bool)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetTotpEnabled(v)
|
||||||
|
return nil
|
||||||
|
case user.FieldTotpEnabledAt:
|
||||||
|
v, ok := value.(time.Time)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetTotpEnabledAt(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown User field %s", name)
|
return fmt.Errorf("unknown User field %s", name)
|
||||||
}
|
}
|
||||||
@@ -15644,6 +15823,12 @@ func (m *UserMutation) ClearedFields() []string {
|
|||||||
if m.FieldCleared(user.FieldDeletedAt) {
|
if m.FieldCleared(user.FieldDeletedAt) {
|
||||||
fields = append(fields, user.FieldDeletedAt)
|
fields = append(fields, user.FieldDeletedAt)
|
||||||
}
|
}
|
||||||
|
if m.FieldCleared(user.FieldTotpSecretEncrypted) {
|
||||||
|
fields = append(fields, user.FieldTotpSecretEncrypted)
|
||||||
|
}
|
||||||
|
if m.FieldCleared(user.FieldTotpEnabledAt) {
|
||||||
|
fields = append(fields, user.FieldTotpEnabledAt)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -15661,6 +15846,12 @@ func (m *UserMutation) ClearField(name string) error {
|
|||||||
case user.FieldDeletedAt:
|
case user.FieldDeletedAt:
|
||||||
m.ClearDeletedAt()
|
m.ClearDeletedAt()
|
||||||
return nil
|
return nil
|
||||||
|
case user.FieldTotpSecretEncrypted:
|
||||||
|
m.ClearTotpSecretEncrypted()
|
||||||
|
return nil
|
||||||
|
case user.FieldTotpEnabledAt:
|
||||||
|
m.ClearTotpEnabledAt()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown User nullable field %s", name)
|
return fmt.Errorf("unknown User nullable field %s", name)
|
||||||
}
|
}
|
||||||
@@ -15702,6 +15893,15 @@ func (m *UserMutation) ResetField(name string) error {
|
|||||||
case user.FieldNotes:
|
case user.FieldNotes:
|
||||||
m.ResetNotes()
|
m.ResetNotes()
|
||||||
return nil
|
return nil
|
||||||
|
case user.FieldTotpSecretEncrypted:
|
||||||
|
m.ResetTotpSecretEncrypted()
|
||||||
|
return nil
|
||||||
|
case user.FieldTotpEnabled:
|
||||||
|
m.ResetTotpEnabled()
|
||||||
|
return nil
|
||||||
|
case user.FieldTotpEnabledAt:
|
||||||
|
m.ResetTotpEnabledAt()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown User field %s", name)
|
return fmt.Errorf("unknown User field %s", name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -736,6 +736,10 @@ func init() {
|
|||||||
userDescNotes := userFields[7].Descriptor()
|
userDescNotes := userFields[7].Descriptor()
|
||||||
// user.DefaultNotes holds the default value on creation for the notes field.
|
// user.DefaultNotes holds the default value on creation for the notes field.
|
||||||
user.DefaultNotes = userDescNotes.Default.(string)
|
user.DefaultNotes = userDescNotes.Default.(string)
|
||||||
|
// userDescTotpEnabled is the schema descriptor for totp_enabled field.
|
||||||
|
userDescTotpEnabled := userFields[9].Descriptor()
|
||||||
|
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
|
||||||
|
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
|
||||||
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
||||||
_ = userallowedgroupFields
|
_ = userallowedgroupFields
|
||||||
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
||||||
|
|||||||
@@ -61,6 +61,17 @@ func (User) Fields() []ent.Field {
|
|||||||
field.String("notes").
|
field.String("notes").
|
||||||
SchemaType(map[string]string{dialect.Postgres: "text"}).
|
SchemaType(map[string]string{dialect.Postgres: "text"}).
|
||||||
Default(""),
|
Default(""),
|
||||||
|
|
||||||
|
// TOTP 双因素认证字段
|
||||||
|
field.String("totp_secret_encrypted").
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "text"}).
|
||||||
|
Optional().
|
||||||
|
Nillable(),
|
||||||
|
field.Bool("totp_enabled").
|
||||||
|
Default(false),
|
||||||
|
field.Time("totp_enabled_at").
|
||||||
|
Optional().
|
||||||
|
Nillable(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -39,6 +39,12 @@ type User struct {
|
|||||||
Username string `json:"username,omitempty"`
|
Username string `json:"username,omitempty"`
|
||||||
// Notes holds the value of the "notes" field.
|
// Notes holds the value of the "notes" field.
|
||||||
Notes string `json:"notes,omitempty"`
|
Notes string `json:"notes,omitempty"`
|
||||||
|
// TotpSecretEncrypted holds the value of the "totp_secret_encrypted" field.
|
||||||
|
TotpSecretEncrypted *string `json:"totp_secret_encrypted,omitempty"`
|
||||||
|
// TotpEnabled holds the value of the "totp_enabled" field.
|
||||||
|
TotpEnabled bool `json:"totp_enabled,omitempty"`
|
||||||
|
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
|
||||||
|
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the UserQuery when eager-loading is set.
|
// The values are being populated by the UserQuery when eager-loading is set.
|
||||||
Edges UserEdges `json:"edges"`
|
Edges UserEdges `json:"edges"`
|
||||||
@@ -156,13 +162,15 @@ func (*User) scanValues(columns []string) ([]any, error) {
|
|||||||
values := make([]any, len(columns))
|
values := make([]any, len(columns))
|
||||||
for i := range columns {
|
for i := range columns {
|
||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
|
case user.FieldTotpEnabled:
|
||||||
|
values[i] = new(sql.NullBool)
|
||||||
case user.FieldBalance:
|
case user.FieldBalance:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case user.FieldID, user.FieldConcurrency:
|
case user.FieldID, user.FieldConcurrency:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes:
|
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt:
|
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
|
||||||
values[i] = new(sql.NullTime)
|
values[i] = new(sql.NullTime)
|
||||||
default:
|
default:
|
||||||
values[i] = new(sql.UnknownType)
|
values[i] = new(sql.UnknownType)
|
||||||
@@ -252,6 +260,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.Notes = value.String
|
_m.Notes = value.String
|
||||||
}
|
}
|
||||||
|
case user.FieldTotpSecretEncrypted:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field totp_secret_encrypted", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.TotpSecretEncrypted = new(string)
|
||||||
|
*_m.TotpSecretEncrypted = value.String
|
||||||
|
}
|
||||||
|
case user.FieldTotpEnabled:
|
||||||
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field totp_enabled", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.TotpEnabled = value.Bool
|
||||||
|
}
|
||||||
|
case user.FieldTotpEnabledAt:
|
||||||
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field totp_enabled_at", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.TotpEnabledAt = new(time.Time)
|
||||||
|
*_m.TotpEnabledAt = value.Time
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@@ -367,6 +395,19 @@ func (_m *User) String() string {
|
|||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
builder.WriteString("notes=")
|
builder.WriteString("notes=")
|
||||||
builder.WriteString(_m.Notes)
|
builder.WriteString(_m.Notes)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
if v := _m.TotpSecretEncrypted; v != nil {
|
||||||
|
builder.WriteString("totp_secret_encrypted=")
|
||||||
|
builder.WriteString(*v)
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("totp_enabled=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.TotpEnabled))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
if v := _m.TotpEnabledAt; v != nil {
|
||||||
|
builder.WriteString("totp_enabled_at=")
|
||||||
|
builder.WriteString(v.Format(time.ANSIC))
|
||||||
|
}
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,12 @@ const (
|
|||||||
FieldUsername = "username"
|
FieldUsername = "username"
|
||||||
// FieldNotes holds the string denoting the notes field in the database.
|
// FieldNotes holds the string denoting the notes field in the database.
|
||||||
FieldNotes = "notes"
|
FieldNotes = "notes"
|
||||||
|
// FieldTotpSecretEncrypted holds the string denoting the totp_secret_encrypted field in the database.
|
||||||
|
FieldTotpSecretEncrypted = "totp_secret_encrypted"
|
||||||
|
// FieldTotpEnabled holds the string denoting the totp_enabled field in the database.
|
||||||
|
FieldTotpEnabled = "totp_enabled"
|
||||||
|
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
|
||||||
|
FieldTotpEnabledAt = "totp_enabled_at"
|
||||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||||
EdgeAPIKeys = "api_keys"
|
EdgeAPIKeys = "api_keys"
|
||||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||||
@@ -134,6 +140,9 @@ var Columns = []string{
|
|||||||
FieldStatus,
|
FieldStatus,
|
||||||
FieldUsername,
|
FieldUsername,
|
||||||
FieldNotes,
|
FieldNotes,
|
||||||
|
FieldTotpSecretEncrypted,
|
||||||
|
FieldTotpEnabled,
|
||||||
|
FieldTotpEnabledAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -188,6 +197,8 @@ var (
|
|||||||
UsernameValidator func(string) error
|
UsernameValidator func(string) error
|
||||||
// DefaultNotes holds the default value on creation for the "notes" field.
|
// DefaultNotes holds the default value on creation for the "notes" field.
|
||||||
DefaultNotes string
|
DefaultNotes string
|
||||||
|
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
|
||||||
|
DefaultTotpEnabled bool
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the User queries.
|
// OrderOption defines the ordering options for the User queries.
|
||||||
@@ -253,6 +264,21 @@ func ByNotes(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldNotes, opts...).ToFunc()
|
return sql.OrderByField(FieldNotes, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByTotpSecretEncrypted orders the results by the totp_secret_encrypted field.
|
||||||
|
func ByTotpSecretEncrypted(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldTotpSecretEncrypted, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByTotpEnabled orders the results by the totp_enabled field.
|
||||||
|
func ByTotpEnabled(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldTotpEnabled, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByTotpEnabledAt orders the results by the totp_enabled_at field.
|
||||||
|
func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByAPIKeysCount orders the results by api_keys count.
|
// ByAPIKeysCount orders the results by api_keys count.
|
||||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -110,6 +110,21 @@ func Notes(v string) predicate.User {
|
|||||||
return predicate.User(sql.FieldEQ(FieldNotes, v))
|
return predicate.User(sql.FieldEQ(FieldNotes, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncrypted applies equality check predicate on the "totp_secret_encrypted" field. It's identical to TotpSecretEncryptedEQ.
|
||||||
|
func TotpSecretEncrypted(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabled applies equality check predicate on the "totp_enabled" field. It's identical to TotpEnabledEQ.
|
||||||
|
func TotpEnabled(v bool) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledAt applies equality check predicate on the "totp_enabled_at" field. It's identical to TotpEnabledAtEQ.
|
||||||
|
func TotpEnabledAt(v time.Time) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.User {
|
func CreatedAtEQ(v time.Time) predicate.User {
|
||||||
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@@ -710,6 +725,141 @@ func NotesContainsFold(v string) predicate.User {
|
|||||||
return predicate.User(sql.FieldContainsFold(FieldNotes, v))
|
return predicate.User(sql.FieldContainsFold(FieldNotes, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedEQ applies the EQ predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedEQ(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedNEQ applies the NEQ predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedNEQ(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNEQ(FieldTotpSecretEncrypted, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedIn applies the In predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedIn(vs ...string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldIn(FieldTotpSecretEncrypted, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedNotIn applies the NotIn predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedNotIn(vs ...string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNotIn(FieldTotpSecretEncrypted, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedGT applies the GT predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedGT(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGT(FieldTotpSecretEncrypted, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedGTE applies the GTE predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedGTE(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGTE(FieldTotpSecretEncrypted, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedLT applies the LT predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedLT(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLT(FieldTotpSecretEncrypted, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedLTE applies the LTE predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedLTE(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLTE(FieldTotpSecretEncrypted, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedContains applies the Contains predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedContains(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldContains(FieldTotpSecretEncrypted, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedHasPrefix applies the HasPrefix predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedHasPrefix(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldHasPrefix(FieldTotpSecretEncrypted, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedHasSuffix applies the HasSuffix predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedHasSuffix(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldHasSuffix(FieldTotpSecretEncrypted, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedIsNil applies the IsNil predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedIsNil() predicate.User {
|
||||||
|
return predicate.User(sql.FieldIsNull(FieldTotpSecretEncrypted))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedNotNil applies the NotNil predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedNotNil() predicate.User {
|
||||||
|
return predicate.User(sql.FieldNotNull(FieldTotpSecretEncrypted))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedEqualFold applies the EqualFold predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedEqualFold(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEqualFold(FieldTotpSecretEncrypted, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSecretEncryptedContainsFold applies the ContainsFold predicate on the "totp_secret_encrypted" field.
|
||||||
|
func TotpSecretEncryptedContainsFold(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldContainsFold(FieldTotpSecretEncrypted, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledEQ applies the EQ predicate on the "totp_enabled" field.
|
||||||
|
func TotpEnabledEQ(v bool) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledNEQ applies the NEQ predicate on the "totp_enabled" field.
|
||||||
|
func TotpEnabledNEQ(v bool) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNEQ(FieldTotpEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledAtEQ applies the EQ predicate on the "totp_enabled_at" field.
|
||||||
|
func TotpEnabledAtEQ(v time.Time) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledAtNEQ applies the NEQ predicate on the "totp_enabled_at" field.
|
||||||
|
func TotpEnabledAtNEQ(v time.Time) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNEQ(FieldTotpEnabledAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledAtIn applies the In predicate on the "totp_enabled_at" field.
|
||||||
|
func TotpEnabledAtIn(vs ...time.Time) predicate.User {
|
||||||
|
return predicate.User(sql.FieldIn(FieldTotpEnabledAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledAtNotIn applies the NotIn predicate on the "totp_enabled_at" field.
|
||||||
|
func TotpEnabledAtNotIn(vs ...time.Time) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNotIn(FieldTotpEnabledAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledAtGT applies the GT predicate on the "totp_enabled_at" field.
|
||||||
|
func TotpEnabledAtGT(v time.Time) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGT(FieldTotpEnabledAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledAtGTE applies the GTE predicate on the "totp_enabled_at" field.
|
||||||
|
func TotpEnabledAtGTE(v time.Time) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGTE(FieldTotpEnabledAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledAtLT applies the LT predicate on the "totp_enabled_at" field.
|
||||||
|
func TotpEnabledAtLT(v time.Time) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLT(FieldTotpEnabledAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledAtLTE applies the LTE predicate on the "totp_enabled_at" field.
|
||||||
|
func TotpEnabledAtLTE(v time.Time) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLTE(FieldTotpEnabledAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledAtIsNil applies the IsNil predicate on the "totp_enabled_at" field.
|
||||||
|
func TotpEnabledAtIsNil() predicate.User {
|
||||||
|
return predicate.User(sql.FieldIsNull(FieldTotpEnabledAt))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnabledAtNotNil applies the NotNil predicate on the "totp_enabled_at" field.
|
||||||
|
func TotpEnabledAtNotNil() predicate.User {
|
||||||
|
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
|
||||||
|
}
|
||||||
|
|
||||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||||
func HasAPIKeys() predicate.User {
|
func HasAPIKeys() predicate.User {
|
||||||
return predicate.User(func(s *sql.Selector) {
|
return predicate.User(func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -167,6 +167,48 @@ func (_c *UserCreate) SetNillableNotes(v *string) *UserCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||||
|
func (_c *UserCreate) SetTotpSecretEncrypted(v string) *UserCreate {
|
||||||
|
_c.mutation.SetTotpSecretEncrypted(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
|
||||||
|
func (_c *UserCreate) SetNillableTotpSecretEncrypted(v *string) *UserCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetTotpSecretEncrypted(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabled sets the "totp_enabled" field.
|
||||||
|
func (_c *UserCreate) SetTotpEnabled(v bool) *UserCreate {
|
||||||
|
_c.mutation.SetTotpEnabled(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
|
||||||
|
func (_c *UserCreate) SetNillableTotpEnabled(v *bool) *UserCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetTotpEnabled(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||||
|
func (_c *UserCreate) SetTotpEnabledAt(v time.Time) *UserCreate {
|
||||||
|
_c.mutation.SetTotpEnabledAt(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
|
||||||
|
func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetTotpEnabledAt(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
||||||
_c.mutation.AddAPIKeyIDs(ids...)
|
_c.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -362,6 +404,10 @@ func (_c *UserCreate) defaults() error {
|
|||||||
v := user.DefaultNotes
|
v := user.DefaultNotes
|
||||||
_c.mutation.SetNotes(v)
|
_c.mutation.SetNotes(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.TotpEnabled(); !ok {
|
||||||
|
v := user.DefaultTotpEnabled
|
||||||
|
_c.mutation.SetTotpEnabled(v)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -422,6 +468,9 @@ func (_c *UserCreate) check() error {
|
|||||||
if _, ok := _c.mutation.Notes(); !ok {
|
if _, ok := _c.mutation.Notes(); !ok {
|
||||||
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
|
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.TotpEnabled(); !ok {
|
||||||
|
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -493,6 +542,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
||||||
_node.Notes = value
|
_node.Notes = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.TotpSecretEncrypted(); ok {
|
||||||
|
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
|
||||||
|
_node.TotpSecretEncrypted = &value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.TotpEnabled(); ok {
|
||||||
|
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
|
||||||
|
_node.TotpEnabled = value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.TotpEnabledAt(); ok {
|
||||||
|
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
||||||
|
_node.TotpEnabledAt = &value
|
||||||
|
}
|
||||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -815,6 +876,54 @@ func (u *UserUpsert) UpdateNotes() *UserUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||||
|
func (u *UserUpsert) SetTotpSecretEncrypted(v string) *UserUpsert {
|
||||||
|
u.Set(user.FieldTotpSecretEncrypted, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsert) UpdateTotpSecretEncrypted() *UserUpsert {
|
||||||
|
u.SetExcluded(user.FieldTotpSecretEncrypted)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||||
|
func (u *UserUpsert) ClearTotpSecretEncrypted() *UserUpsert {
|
||||||
|
u.SetNull(user.FieldTotpSecretEncrypted)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabled sets the "totp_enabled" field.
|
||||||
|
func (u *UserUpsert) SetTotpEnabled(v bool) *UserUpsert {
|
||||||
|
u.Set(user.FieldTotpEnabled, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsert) UpdateTotpEnabled() *UserUpsert {
|
||||||
|
u.SetExcluded(user.FieldTotpEnabled)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||||
|
func (u *UserUpsert) SetTotpEnabledAt(v time.Time) *UserUpsert {
|
||||||
|
u.Set(user.FieldTotpEnabledAt, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsert) UpdateTotpEnabledAt() *UserUpsert {
|
||||||
|
u.SetExcluded(user.FieldTotpEnabledAt)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||||
|
func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
|
||||||
|
u.SetNull(user.FieldTotpEnabledAt)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@@ -1021,6 +1130,62 @@ func (u *UserUpsertOne) UpdateNotes() *UserUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||||
|
func (u *UserUpsertOne) SetTotpSecretEncrypted(v string) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetTotpSecretEncrypted(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertOne) UpdateTotpSecretEncrypted() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateTotpSecretEncrypted()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||||
|
func (u *UserUpsertOne) ClearTotpSecretEncrypted() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.ClearTotpSecretEncrypted()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabled sets the "totp_enabled" field.
|
||||||
|
func (u *UserUpsertOne) SetTotpEnabled(v bool) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetTotpEnabled(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertOne) UpdateTotpEnabled() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateTotpEnabled()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||||
|
func (u *UserUpsertOne) SetTotpEnabledAt(v time.Time) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetTotpEnabledAt(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertOne) UpdateTotpEnabledAt() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateTotpEnabledAt()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||||
|
func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.ClearTotpEnabledAt()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@@ -1393,6 +1558,62 @@ func (u *UserUpsertBulk) UpdateNotes() *UserUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||||
|
func (u *UserUpsertBulk) SetTotpSecretEncrypted(v string) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetTotpSecretEncrypted(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertBulk) UpdateTotpSecretEncrypted() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateTotpSecretEncrypted()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||||
|
func (u *UserUpsertBulk) ClearTotpSecretEncrypted() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.ClearTotpSecretEncrypted()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabled sets the "totp_enabled" field.
|
||||||
|
func (u *UserUpsertBulk) SetTotpEnabled(v bool) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetTotpEnabled(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertBulk) UpdateTotpEnabled() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateTotpEnabled()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||||
|
func (u *UserUpsertBulk) SetTotpEnabledAt(v time.Time) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetTotpEnabledAt(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertBulk) UpdateTotpEnabledAt() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateTotpEnabledAt()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||||
|
func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.ClearTotpEnabledAt()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@@ -187,6 +187,60 @@ func (_u *UserUpdate) SetNillableNotes(v *string) *UserUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||||
|
func (_u *UserUpdate) SetTotpSecretEncrypted(v string) *UserUpdate {
|
||||||
|
_u.mutation.SetTotpSecretEncrypted(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdate) SetNillableTotpSecretEncrypted(v *string) *UserUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetTotpSecretEncrypted(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||||
|
func (_u *UserUpdate) ClearTotpSecretEncrypted() *UserUpdate {
|
||||||
|
_u.mutation.ClearTotpSecretEncrypted()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabled sets the "totp_enabled" field.
|
||||||
|
func (_u *UserUpdate) SetTotpEnabled(v bool) *UserUpdate {
|
||||||
|
_u.mutation.SetTotpEnabled(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdate) SetNillableTotpEnabled(v *bool) *UserUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetTotpEnabled(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||||
|
func (_u *UserUpdate) SetTotpEnabledAt(v time.Time) *UserUpdate {
|
||||||
|
_u.mutation.SetTotpEnabledAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdate) SetNillableTotpEnabledAt(v *time.Time) *UserUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetTotpEnabledAt(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||||
|
func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
|
||||||
|
_u.mutation.ClearTotpEnabledAt()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -603,6 +657,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.Notes(); ok {
|
if value, ok := _u.mutation.Notes(); ok {
|
||||||
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
|
||||||
|
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.TotpSecretEncryptedCleared() {
|
||||||
|
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.TotpEnabled(); ok {
|
||||||
|
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.TotpEnabledAt(); ok {
|
||||||
|
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.TotpEnabledAtCleared() {
|
||||||
|
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1147,6 +1216,60 @@ func (_u *UserUpdateOne) SetNillableNotes(v *string) *UserUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
|
||||||
|
func (_u *UserUpdateOne) SetTotpSecretEncrypted(v string) *UserUpdateOne {
|
||||||
|
_u.mutation.SetTotpSecretEncrypted(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdateOne) SetNillableTotpSecretEncrypted(v *string) *UserUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetTotpSecretEncrypted(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
|
||||||
|
func (_u *UserUpdateOne) ClearTotpSecretEncrypted() *UserUpdateOne {
|
||||||
|
_u.mutation.ClearTotpSecretEncrypted()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabled sets the "totp_enabled" field.
|
||||||
|
func (_u *UserUpdateOne) SetTotpEnabled(v bool) *UserUpdateOne {
|
||||||
|
_u.mutation.SetTotpEnabled(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdateOne) SetNillableTotpEnabled(v *bool) *UserUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetTotpEnabled(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotpEnabledAt sets the "totp_enabled_at" field.
|
||||||
|
func (_u *UserUpdateOne) SetTotpEnabledAt(v time.Time) *UserUpdateOne {
|
||||||
|
_u.mutation.SetTotpEnabledAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdateOne) SetNillableTotpEnabledAt(v *time.Time) *UserUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetTotpEnabledAt(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
|
||||||
|
func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
|
||||||
|
_u.mutation.ClearTotpEnabledAt()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -1593,6 +1716,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
|||||||
if value, ok := _u.mutation.Notes(); ok {
|
if value, ok := _u.mutation.Notes(); ok {
|
||||||
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
_spec.SetField(user.FieldNotes, field.TypeString, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
|
||||||
|
_spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.TotpSecretEncryptedCleared() {
|
||||||
|
_spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.TotpEnabled(); ok {
|
||||||
|
_spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.TotpEnabledAt(); ok {
|
||||||
|
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.TotpEnabledAtCleared() {
|
||||||
|
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
module github.com/Wei-Shaw/sub2api
|
module github.com/Wei-Shaw/sub2api
|
||||||
|
|
||||||
go 1.25.5
|
go 1.25.6
|
||||||
|
|
||||||
require (
|
require (
|
||||||
entgo.io/ent v0.14.5
|
entgo.io/ent v0.14.5
|
||||||
@@ -37,6 +37,7 @@ require (
|
|||||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
|
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
|
||||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||||
|
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
|
||||||
github.com/bytedance/sonic v1.9.1 // indirect
|
github.com/bytedance/sonic v1.9.1 // indirect
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
@@ -106,6 +107,7 @@ require (
|
|||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||||
|
github.com/pquerna/otp v1.5.0 // indirect
|
||||||
github.com/quic-go/qpack v0.6.0 // indirect
|
github.com/quic-go/qpack v0.6.0 // indirect
|
||||||
github.com/quic-go/quic-go v0.57.1 // indirect
|
github.com/quic-go/quic-go v0.57.1 // indirect
|
||||||
github.com/refraction-networking/utls v1.8.1 // indirect
|
github.com/refraction-networking/utls v1.8.1 // indirect
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
|
|||||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||||
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
|
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
|
||||||
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
|
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
|
||||||
|
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
|
||||||
|
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||||
@@ -217,6 +219,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
|
|||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||||
|
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
|
||||||
|
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
|
||||||
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
||||||
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
|
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
|
||||||
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ type Config struct {
|
|||||||
Redis RedisConfig `mapstructure:"redis"`
|
Redis RedisConfig `mapstructure:"redis"`
|
||||||
Ops OpsConfig `mapstructure:"ops"`
|
Ops OpsConfig `mapstructure:"ops"`
|
||||||
JWT JWTConfig `mapstructure:"jwt"`
|
JWT JWTConfig `mapstructure:"jwt"`
|
||||||
|
Totp TotpConfig `mapstructure:"totp"`
|
||||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||||
Default DefaultConfig `mapstructure:"default"`
|
Default DefaultConfig `mapstructure:"default"`
|
||||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||||
@@ -466,6 +467,16 @@ type JWTConfig struct {
|
|||||||
ExpireHour int `mapstructure:"expire_hour"`
|
ExpireHour int `mapstructure:"expire_hour"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TotpConfig TOTP 双因素认证配置
|
||||||
|
type TotpConfig struct {
|
||||||
|
// EncryptionKey 用于加密 TOTP 密钥的 AES-256 密钥(32 字节 hex 编码)
|
||||||
|
// 如果为空,将自动生成一个随机密钥(仅适用于开发环境)
|
||||||
|
EncryptionKey string `mapstructure:"encryption_key"`
|
||||||
|
// EncryptionKeyConfigured 标记加密密钥是否为手动配置(非自动生成)
|
||||||
|
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
|
||||||
|
EncryptionKeyConfigured bool `mapstructure:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
type TurnstileConfig struct {
|
type TurnstileConfig struct {
|
||||||
Required bool `mapstructure:"required"`
|
Required bool `mapstructure:"required"`
|
||||||
}
|
}
|
||||||
@@ -626,6 +637,20 @@ func Load() (*Config, error) {
|
|||||||
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
|
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
|
||||||
|
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
|
||||||
|
if cfg.Totp.EncryptionKey == "" {
|
||||||
|
key, err := generateJWTSecret(32) // Reuse the same random generation function
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate totp encryption key error: %w", err)
|
||||||
|
}
|
||||||
|
cfg.Totp.EncryptionKey = key
|
||||||
|
cfg.Totp.EncryptionKeyConfigured = false
|
||||||
|
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
|
||||||
|
} else {
|
||||||
|
cfg.Totp.EncryptionKeyConfigured = true
|
||||||
|
}
|
||||||
|
|
||||||
if err := cfg.Validate(); err != nil {
|
if err := cfg.Validate(); err != nil {
|
||||||
return nil, fmt.Errorf("validate config error: %w", err)
|
return nil, fmt.Errorf("validate config error: %w", err)
|
||||||
}
|
}
|
||||||
@@ -756,6 +781,9 @@ func setDefaults() {
|
|||||||
viper.SetDefault("jwt.secret", "")
|
viper.SetDefault("jwt.secret", "")
|
||||||
viper.SetDefault("jwt.expire_hour", 24)
|
viper.SetDefault("jwt.expire_hour", 24)
|
||||||
|
|
||||||
|
// TOTP
|
||||||
|
viper.SetDefault("totp.encryption_key", "")
|
||||||
|
|
||||||
// Default
|
// Default
|
||||||
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
|
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
|
||||||
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.
|
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ type AccountHandler struct {
|
|||||||
concurrencyService *service.ConcurrencyService
|
concurrencyService *service.ConcurrencyService
|
||||||
crsSyncService *service.CRSSyncService
|
crsSyncService *service.CRSSyncService
|
||||||
sessionLimitCache service.SessionLimitCache
|
sessionLimitCache service.SessionLimitCache
|
||||||
|
tokenCacheInvalidator service.TokenCacheInvalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountHandler creates a new admin account handler
|
// NewAccountHandler creates a new admin account handler
|
||||||
@@ -60,6 +61,7 @@ func NewAccountHandler(
|
|||||||
concurrencyService *service.ConcurrencyService,
|
concurrencyService *service.ConcurrencyService,
|
||||||
crsSyncService *service.CRSSyncService,
|
crsSyncService *service.CRSSyncService,
|
||||||
sessionLimitCache service.SessionLimitCache,
|
sessionLimitCache service.SessionLimitCache,
|
||||||
|
tokenCacheInvalidator service.TokenCacheInvalidator,
|
||||||
) *AccountHandler {
|
) *AccountHandler {
|
||||||
return &AccountHandler{
|
return &AccountHandler{
|
||||||
adminService: adminService,
|
adminService: adminService,
|
||||||
@@ -73,6 +75,7 @@ func NewAccountHandler(
|
|||||||
concurrencyService: concurrencyService,
|
concurrencyService: concurrencyService,
|
||||||
crsSyncService: crsSyncService,
|
crsSyncService: crsSyncService,
|
||||||
sessionLimitCache: sessionLimitCache,
|
sessionLimitCache: sessionLimitCache,
|
||||||
|
tokenCacheInvalidator: tokenCacheInvalidator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -173,6 +176,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||||
windowCostAccountIDs := make([]int64, 0)
|
windowCostAccountIDs := make([]int64, 0)
|
||||||
sessionLimitAccountIDs := make([]int64, 0)
|
sessionLimitAccountIDs := make([]int64, 0)
|
||||||
|
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||||
for i := range accounts {
|
for i := range accounts {
|
||||||
acc := &accounts[i]
|
acc := &accounts[i]
|
||||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||||
@@ -181,6 +185,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if acc.GetMaxSessions() > 0 {
|
if acc.GetMaxSessions() > 0 {
|
||||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||||
|
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -189,9 +194,9 @@ func (h *AccountHandler) List(c *gin.Context) {
|
|||||||
var windowCosts map[int64]float64
|
var windowCosts map[int64]float64
|
||||||
var activeSessions map[int64]int
|
var activeSessions map[int64]int
|
||||||
|
|
||||||
// 获取活跃会话数(批量查询)
|
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
||||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
|
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||||
if activeSessions == nil {
|
if activeSessions == nil {
|
||||||
activeSessions = make(map[int64]int)
|
activeSessions = make(map[int64]int)
|
||||||
}
|
}
|
||||||
@@ -542,9 +547,18 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
|
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||||
|
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||||
|
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
|
||||||
|
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
|
||||||
|
newCredentials["project_id"] = oldProjectID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||||
|
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
||||||
if tokenInfo.ProjectIDMissing {
|
if tokenInfo.ProjectIDMissing {
|
||||||
// 先更新凭证
|
// 先更新凭证(token 本身刷新成功了)
|
||||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||||
Credentials: newCredentials,
|
Credentials: newCredentials,
|
||||||
})
|
})
|
||||||
@@ -552,14 +566,10 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
|||||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 标记账户为 error
|
// 不标记为 error,只返回警告信息
|
||||||
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id,可能无法使用Antigravity"); setErr != nil {
|
|
||||||
response.InternalError(c, "Failed to set account error: "+setErr.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"message": "Token refreshed but project_id is missing, account marked as error",
|
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||||
"warning": "missing_project_id",
|
"warning": "missing_project_id_temporary",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -606,6 +616,14 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||||
|
if h.tokenCacheInvalidator != nil {
|
||||||
|
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
|
||||||
|
// 缓存失效失败只记录日志,不影响主流程
|
||||||
|
_ = c.Error(invalidateErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -655,6 +673,15 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 清除错误后,同时清除 token 缓存,确保下次请求会获取最新的 token(触发刷新或从 DB 读取)
|
||||||
|
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
||||||
|
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||||
|
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
||||||
|
// 缓存失效失败只记录日志,不影响主流程
|
||||||
|
_ = c.Error(invalidateErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
response.Success(c, dto.AccountFromService(account))
|
response.Success(c, dto.AccountFromService(account))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -94,9 +94,9 @@ func (h *GroupHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
outGroups := make([]dto.Group, 0, len(groups))
|
outGroups := make([]dto.AdminGroup, 0, len(groups))
|
||||||
for i := range groups {
|
for i := range groups {
|
||||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
|
||||||
}
|
}
|
||||||
response.Paginated(c, outGroups, total, page, pageSize)
|
response.Paginated(c, outGroups, total, page, pageSize)
|
||||||
}
|
}
|
||||||
@@ -120,9 +120,9 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
outGroups := make([]dto.Group, 0, len(groups))
|
outGroups := make([]dto.AdminGroup, 0, len(groups))
|
||||||
for i := range groups {
|
for i := range groups {
|
||||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
|
||||||
}
|
}
|
||||||
response.Success(c, outGroups)
|
response.Success(c, outGroups)
|
||||||
}
|
}
|
||||||
@@ -142,7 +142,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.GroupFromService(group))
|
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create handles creating a new group
|
// Create handles creating a new group
|
||||||
@@ -177,7 +177,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.GroupFromService(group))
|
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update handles updating a group
|
// Update handles updating a group
|
||||||
@@ -219,7 +219,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.GroupFromService(group))
|
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete handles deleting a group
|
// Delete handles deleting a group
|
||||||
|
|||||||
@@ -54,9 +54,9 @@ func (h *RedeemHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
out := make([]dto.RedeemCode, 0, len(codes))
|
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||||
for i := range codes {
|
for i := range codes {
|
||||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||||
}
|
}
|
||||||
response.Paginated(c, out, total, page, pageSize)
|
response.Paginated(c, out, total, page, pageSize)
|
||||||
}
|
}
|
||||||
@@ -76,7 +76,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.RedeemCodeFromService(code))
|
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate handles generating new redeem codes
|
// Generate handles generating new redeem codes
|
||||||
@@ -100,9 +100,9 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
out := make([]dto.RedeemCode, 0, len(codes))
|
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||||
for i := range codes {
|
for i := range codes {
|
||||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||||
}
|
}
|
||||||
response.Success(c, out)
|
response.Success(c, out)
|
||||||
}
|
}
|
||||||
@@ -163,7 +163,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.RedeemCodeFromService(code))
|
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStats handles getting redeem code statistics
|
// GetStats handles getting redeem code statistics
|
||||||
|
|||||||
@@ -47,6 +47,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
response.Success(c, dto.SystemSettings{
|
response.Success(c, dto.SystemSettings{
|
||||||
RegistrationEnabled: settings.RegistrationEnabled,
|
RegistrationEnabled: settings.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||||
|
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||||
|
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||||
|
TotpEnabled: settings.TotpEnabled,
|
||||||
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
SMTPHost: settings.SMTPHost,
|
SMTPHost: settings.SMTPHost,
|
||||||
SMTPPort: settings.SMTPPort,
|
SMTPPort: settings.SMTPPort,
|
||||||
SMTPUsername: settings.SMTPUsername,
|
SMTPUsername: settings.SMTPUsername,
|
||||||
@@ -68,6 +72,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
ContactInfo: settings.ContactInfo,
|
ContactInfo: settings.ContactInfo,
|
||||||
DocURL: settings.DocURL,
|
DocURL: settings.DocURL,
|
||||||
HomeContent: settings.HomeContent,
|
HomeContent: settings.HomeContent,
|
||||||
|
HideCcsImportButton: settings.HideCcsImportButton,
|
||||||
|
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||||
|
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||||
DefaultConcurrency: settings.DefaultConcurrency,
|
DefaultConcurrency: settings.DefaultConcurrency,
|
||||||
DefaultBalance: settings.DefaultBalance,
|
DefaultBalance: settings.DefaultBalance,
|
||||||
EnableModelFallback: settings.EnableModelFallback,
|
EnableModelFallback: settings.EnableModelFallback,
|
||||||
@@ -87,8 +94,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
// UpdateSettingsRequest 更新设置请求
|
// UpdateSettingsRequest 更新设置请求
|
||||||
type UpdateSettingsRequest struct {
|
type UpdateSettingsRequest struct {
|
||||||
// 注册设置
|
// 注册设置
|
||||||
RegistrationEnabled bool `json:"registration_enabled"`
|
RegistrationEnabled bool `json:"registration_enabled"`
|
||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
|
|
||||||
// 邮件服务设置
|
// 邮件服务设置
|
||||||
SMTPHost string `json:"smtp_host"`
|
SMTPHost string `json:"smtp_host"`
|
||||||
@@ -111,13 +121,16 @@ type UpdateSettingsRequest struct {
|
|||||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||||
|
|
||||||
// OEM设置
|
// OEM设置
|
||||||
SiteName string `json:"site_name"`
|
SiteName string `json:"site_name"`
|
||||||
SiteLogo string `json:"site_logo"`
|
SiteLogo string `json:"site_logo"`
|
||||||
SiteSubtitle string `json:"site_subtitle"`
|
SiteSubtitle string `json:"site_subtitle"`
|
||||||
APIBaseURL string `json:"api_base_url"`
|
APIBaseURL string `json:"api_base_url"`
|
||||||
ContactInfo string `json:"contact_info"`
|
ContactInfo string `json:"contact_info"`
|
||||||
DocURL string `json:"doc_url"`
|
DocURL string `json:"doc_url"`
|
||||||
HomeContent string `json:"home_content"`
|
HomeContent string `json:"home_content"`
|
||||||
|
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||||
|
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
||||||
|
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
||||||
|
|
||||||
// 默认配置
|
// 默认配置
|
||||||
DefaultConcurrency int `json:"default_concurrency"`
|
DefaultConcurrency int `json:"default_concurrency"`
|
||||||
@@ -194,6 +207,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TOTP 双因素认证参数验证
|
||||||
|
// 只有手动配置了加密密钥才允许启用 TOTP 功能
|
||||||
|
if req.TotpEnabled && !previousSettings.TotpEnabled {
|
||||||
|
// 尝试启用 TOTP,检查加密密钥是否已手动配置
|
||||||
|
if !h.settingService.IsTotpEncryptionKeyConfigured() {
|
||||||
|
response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// LinuxDo Connect 参数验证
|
// LinuxDo Connect 参数验证
|
||||||
if req.LinuxDoConnectEnabled {
|
if req.LinuxDoConnectEnabled {
|
||||||
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
|
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
|
||||||
@@ -223,6 +246,34 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// “购买订阅”页面配置验证
|
||||||
|
purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled
|
||||||
|
if req.PurchaseSubscriptionEnabled != nil {
|
||||||
|
purchaseEnabled = *req.PurchaseSubscriptionEnabled
|
||||||
|
}
|
||||||
|
purchaseURL := previousSettings.PurchaseSubscriptionURL
|
||||||
|
if req.PurchaseSubscriptionURL != nil {
|
||||||
|
purchaseURL = strings.TrimSpace(*req.PurchaseSubscriptionURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// - 启用时要求 URL 合法且非空
|
||||||
|
// - 禁用时允许为空;若提供了 URL 也做基本校验,避免误配置
|
||||||
|
if purchaseEnabled {
|
||||||
|
if purchaseURL == "" {
|
||||||
|
response.BadRequest(c, "Purchase Subscription URL is required when enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
|
||||||
|
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if purchaseURL != "" {
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(purchaseURL); err != nil {
|
||||||
|
response.BadRequest(c, "Purchase Subscription URL must be an absolute http(s) URL")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Ops metrics collector interval validation (seconds).
|
// Ops metrics collector interval validation (seconds).
|
||||||
if req.OpsMetricsIntervalSeconds != nil {
|
if req.OpsMetricsIntervalSeconds != nil {
|
||||||
v := *req.OpsMetricsIntervalSeconds
|
v := *req.OpsMetricsIntervalSeconds
|
||||||
@@ -236,38 +287,44 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
settings := &service.SystemSettings{
|
settings := &service.SystemSettings{
|
||||||
RegistrationEnabled: req.RegistrationEnabled,
|
RegistrationEnabled: req.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||||
SMTPHost: req.SMTPHost,
|
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||||
SMTPPort: req.SMTPPort,
|
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||||
SMTPUsername: req.SMTPUsername,
|
TotpEnabled: req.TotpEnabled,
|
||||||
SMTPPassword: req.SMTPPassword,
|
SMTPHost: req.SMTPHost,
|
||||||
SMTPFrom: req.SMTPFrom,
|
SMTPPort: req.SMTPPort,
|
||||||
SMTPFromName: req.SMTPFromName,
|
SMTPUsername: req.SMTPUsername,
|
||||||
SMTPUseTLS: req.SMTPUseTLS,
|
SMTPPassword: req.SMTPPassword,
|
||||||
TurnstileEnabled: req.TurnstileEnabled,
|
SMTPFrom: req.SMTPFrom,
|
||||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
SMTPFromName: req.SMTPFromName,
|
||||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
SMTPUseTLS: req.SMTPUseTLS,
|
||||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
TurnstileEnabled: req.TurnstileEnabled,
|
||||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||||
SiteName: req.SiteName,
|
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||||
SiteLogo: req.SiteLogo,
|
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||||
SiteSubtitle: req.SiteSubtitle,
|
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||||
APIBaseURL: req.APIBaseURL,
|
SiteName: req.SiteName,
|
||||||
ContactInfo: req.ContactInfo,
|
SiteLogo: req.SiteLogo,
|
||||||
DocURL: req.DocURL,
|
SiteSubtitle: req.SiteSubtitle,
|
||||||
HomeContent: req.HomeContent,
|
APIBaseURL: req.APIBaseURL,
|
||||||
DefaultConcurrency: req.DefaultConcurrency,
|
ContactInfo: req.ContactInfo,
|
||||||
DefaultBalance: req.DefaultBalance,
|
DocURL: req.DocURL,
|
||||||
EnableModelFallback: req.EnableModelFallback,
|
HomeContent: req.HomeContent,
|
||||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
HideCcsImportButton: req.HideCcsImportButton,
|
||||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||||
FallbackModelGemini: req.FallbackModelGemini,
|
PurchaseSubscriptionURL: purchaseURL,
|
||||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
DefaultConcurrency: req.DefaultConcurrency,
|
||||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
DefaultBalance: req.DefaultBalance,
|
||||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
EnableModelFallback: req.EnableModelFallback,
|
||||||
|
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||||
|
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||||
|
FallbackModelGemini: req.FallbackModelGemini,
|
||||||
|
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||||
|
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||||
|
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||||
OpsMonitoringEnabled: func() bool {
|
OpsMonitoringEnabled: func() bool {
|
||||||
if req.OpsMonitoringEnabled != nil {
|
if req.OpsMonitoringEnabled != nil {
|
||||||
return *req.OpsMonitoringEnabled
|
return *req.OpsMonitoringEnabled
|
||||||
@@ -311,6 +368,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
response.Success(c, dto.SystemSettings{
|
response.Success(c, dto.SystemSettings{
|
||||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||||
|
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||||
|
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||||
|
TotpEnabled: updatedSettings.TotpEnabled,
|
||||||
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
SMTPHost: updatedSettings.SMTPHost,
|
SMTPHost: updatedSettings.SMTPHost,
|
||||||
SMTPPort: updatedSettings.SMTPPort,
|
SMTPPort: updatedSettings.SMTPPort,
|
||||||
SMTPUsername: updatedSettings.SMTPUsername,
|
SMTPUsername: updatedSettings.SMTPUsername,
|
||||||
@@ -332,6 +393,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
ContactInfo: updatedSettings.ContactInfo,
|
ContactInfo: updatedSettings.ContactInfo,
|
||||||
DocURL: updatedSettings.DocURL,
|
DocURL: updatedSettings.DocURL,
|
||||||
HomeContent: updatedSettings.HomeContent,
|
HomeContent: updatedSettings.HomeContent,
|
||||||
|
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
||||||
|
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
|
||||||
|
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
||||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||||
DefaultBalance: updatedSettings.DefaultBalance,
|
DefaultBalance: updatedSettings.DefaultBalance,
|
||||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||||
@@ -376,6 +440,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
|
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
|
||||||
changed = append(changed, "email_verify_enabled")
|
changed = append(changed, "email_verify_enabled")
|
||||||
}
|
}
|
||||||
|
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||||
|
changed = append(changed, "password_reset_enabled")
|
||||||
|
}
|
||||||
|
if before.TotpEnabled != after.TotpEnabled {
|
||||||
|
changed = append(changed, "totp_enabled")
|
||||||
|
}
|
||||||
if before.SMTPHost != after.SMTPHost {
|
if before.SMTPHost != after.SMTPHost {
|
||||||
changed = append(changed, "smtp_host")
|
changed = append(changed, "smtp_host")
|
||||||
}
|
}
|
||||||
@@ -439,6 +509,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.HomeContent != after.HomeContent {
|
if before.HomeContent != after.HomeContent {
|
||||||
changed = append(changed, "home_content")
|
changed = append(changed, "home_content")
|
||||||
}
|
}
|
||||||
|
if before.HideCcsImportButton != after.HideCcsImportButton {
|
||||||
|
changed = append(changed, "hide_ccs_import_button")
|
||||||
|
}
|
||||||
if before.DefaultConcurrency != after.DefaultConcurrency {
|
if before.DefaultConcurrency != after.DefaultConcurrency {
|
||||||
changed = append(changed, "default_concurrency")
|
changed = append(changed, "default_concurrency")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,9 +53,9 @@ type BulkAssignSubscriptionRequest struct {
|
|||||||
Notes string `json:"notes"`
|
Notes string `json:"notes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExtendSubscriptionRequest represents extend subscription request
|
// AdjustSubscriptionRequest represents adjust subscription request (extend or shorten)
|
||||||
type ExtendSubscriptionRequest struct {
|
type AdjustSubscriptionRequest struct {
|
||||||
Days int `json:"days" binding:"required,min=1,max=36500"` // max 100 years
|
Days int `json:"days" binding:"required,min=-36500,max=36500"` // negative to shorten, positive to extend
|
||||||
}
|
}
|
||||||
|
|
||||||
// List handles listing all subscriptions with pagination and filters
|
// List handles listing all subscriptions with pagination and filters
|
||||||
@@ -77,15 +77,19 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
status := c.Query("status")
|
status := c.Query("status")
|
||||||
|
|
||||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
|
// Parse sorting parameters
|
||||||
|
sortBy := c.DefaultQuery("sort_by", "created_at")
|
||||||
|
sortOrder := c.DefaultQuery("sort_order", "desc")
|
||||||
|
|
||||||
|
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||||
for i := range subscriptions {
|
for i := range subscriptions {
|
||||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||||
}
|
}
|
||||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||||
}
|
}
|
||||||
@@ -105,7 +109,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProgress handles getting subscription usage progress
|
// GetProgress handles getting subscription usage progress
|
||||||
@@ -150,7 +154,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||||
}
|
}
|
||||||
|
|
||||||
// BulkAssign handles bulk assigning subscriptions to multiple users
|
// BulkAssign handles bulk assigning subscriptions to multiple users
|
||||||
@@ -180,7 +184,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
|
|||||||
response.Success(c, dto.BulkAssignResultFromService(result))
|
response.Success(c, dto.BulkAssignResultFromService(result))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extend handles extending a subscription
|
// Extend handles adjusting a subscription (extend or shorten)
|
||||||
// POST /api/v1/admin/subscriptions/:id/extend
|
// POST /api/v1/admin/subscriptions/:id/extend
|
||||||
func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
@@ -189,7 +193,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req ExtendSubscriptionRequest
|
var req AdjustSubscriptionRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
return
|
return
|
||||||
@@ -201,7 +205,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Revoke handles revoking a subscription
|
// Revoke handles revoking a subscription
|
||||||
@@ -239,9 +243,9 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||||
for i := range subscriptions {
|
for i := range subscriptions {
|
||||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||||
}
|
}
|
||||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||||
}
|
}
|
||||||
@@ -261,9 +265,9 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||||
for i := range subscriptions {
|
for i := range subscriptions {
|
||||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||||
}
|
}
|
||||||
response.Success(c, out)
|
response.Success(c, out)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
out := make([]dto.UsageLog, 0, len(records))
|
out := make([]dto.AdminUsageLog, 0, len(records))
|
||||||
for i := range records {
|
for i := range records {
|
||||||
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
|
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -84,9 +84,9 @@ func (h *UserHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
out := make([]dto.User, 0, len(users))
|
out := make([]dto.AdminUser, 0, len(users))
|
||||||
for i := range users {
|
for i := range users {
|
||||||
out = append(out, *dto.UserFromService(&users[i]))
|
out = append(out, *dto.UserFromServiceAdmin(&users[i]))
|
||||||
}
|
}
|
||||||
response.Paginated(c, out, total, page, pageSize)
|
response.Paginated(c, out, total, page, pageSize)
|
||||||
}
|
}
|
||||||
@@ -129,7 +129,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.UserFromService(user))
|
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create handles creating a new user
|
// Create handles creating a new user
|
||||||
@@ -155,7 +155,7 @@ func (h *UserHandler) Create(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.UserFromService(user))
|
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update handles updating a user
|
// Update handles updating a user
|
||||||
@@ -189,7 +189,7 @@ func (h *UserHandler) Update(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.UserFromService(user))
|
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete handles deleting a user
|
// Delete handles deleting a user
|
||||||
@@ -231,7 +231,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.UserFromService(user))
|
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserAPIKeys handles getting user's API keys
|
// GetUserAPIKeys handles getting user's API keys
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
@@ -18,16 +20,18 @@ type AuthHandler struct {
|
|||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
settingSvc *service.SettingService
|
settingSvc *service.SettingService
|
||||||
promoService *service.PromoService
|
promoService *service.PromoService
|
||||||
|
totpService *service.TotpService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthHandler creates a new AuthHandler
|
// NewAuthHandler creates a new AuthHandler
|
||||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
|
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler {
|
||||||
return &AuthHandler{
|
return &AuthHandler{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
authService: authService,
|
authService: authService,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
settingSvc: settingService,
|
settingSvc: settingService,
|
||||||
promoService: promoService,
|
promoService: promoService,
|
||||||
|
totpService: totpService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,6 +148,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if TOTP 2FA is enabled for this user
|
||||||
|
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
|
||||||
|
// Create a temporary login session for 2FA
|
||||||
|
tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, "Failed to create 2FA session")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, TotpLoginResponse{
|
||||||
|
Requires2FA: true,
|
||||||
|
TempToken: tempToken,
|
||||||
|
UserEmailMasked: service.MaskEmail(user.Email),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, AuthResponse{
|
||||||
|
AccessToken: token,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
User: dto.UserFromService(user),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpLoginResponse represents the response when 2FA is required
|
||||||
|
type TotpLoginResponse struct {
|
||||||
|
Requires2FA bool `json:"requires_2fa"`
|
||||||
|
TempToken string `json:"temp_token,omitempty"`
|
||||||
|
UserEmailMasked string `json:"user_email_masked,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login2FARequest represents the 2FA login request
|
||||||
|
type Login2FARequest struct {
|
||||||
|
TempToken string `json:"temp_token" binding:"required"`
|
||||||
|
TotpCode string `json:"totp_code" binding:"required,len=6"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login2FA completes the login with 2FA verification
|
||||||
|
// POST /api/v1/auth/login/2fa
|
||||||
|
func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||||
|
var req Login2FARequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("login_2fa_request",
|
||||||
|
"temp_token_len", len(req.TempToken),
|
||||||
|
"totp_code_len", len(req.TotpCode))
|
||||||
|
|
||||||
|
// Get the login session
|
||||||
|
session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
|
||||||
|
if err != nil || session == nil {
|
||||||
|
tokenPrefix := ""
|
||||||
|
if len(req.TempToken) >= 8 {
|
||||||
|
tokenPrefix = req.TempToken[:8]
|
||||||
|
}
|
||||||
|
slog.Debug("login_2fa_session_invalid",
|
||||||
|
"temp_token_prefix", tokenPrefix,
|
||||||
|
"error", err)
|
||||||
|
response.BadRequest(c, "Invalid or expired 2FA session")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("login_2fa_session_found",
|
||||||
|
"user_id", session.UserID,
|
||||||
|
"email", session.Email)
|
||||||
|
|
||||||
|
// Verify the TOTP code
|
||||||
|
if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
|
||||||
|
slog.Debug("login_2fa_verify_failed",
|
||||||
|
"user_id", session.UserID,
|
||||||
|
"error", err)
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the login session
|
||||||
|
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||||
|
|
||||||
|
// Get the user
|
||||||
|
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate the JWT token
|
||||||
|
token, err := h.authService.GenerateToken(user)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalError(c, "Failed to generate token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
response.Success(c, AuthResponse{
|
response.Success(c, AuthResponse{
|
||||||
AccessToken: token,
|
AccessToken: token,
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
@@ -195,6 +293,15 @@ type ValidatePromoCodeResponse struct {
|
|||||||
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
|
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
|
||||||
// POST /api/v1/auth/validate-promo-code
|
// POST /api/v1/auth/validate-promo-code
|
||||||
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||||
|
// 检查优惠码功能是否启用
|
||||||
|
if h.settingSvc != nil && !h.settingSvc.IsPromoCodeEnabled(c.Request.Context()) {
|
||||||
|
response.Success(c, ValidatePromoCodeResponse{
|
||||||
|
Valid: false,
|
||||||
|
ErrorCode: "PROMO_CODE_DISABLED",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var req ValidatePromoCodeRequest
|
var req ValidatePromoCodeRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
@@ -238,3 +345,85 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
|||||||
BonusAmount: promoCode.BonusAmount,
|
BonusAmount: promoCode.BonusAmount,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ForgotPasswordRequest 忘记密码请求
|
||||||
|
type ForgotPasswordRequest struct {
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
TurnstileToken string `json:"turnstile_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForgotPasswordResponse 忘记密码响应
|
||||||
|
type ForgotPasswordResponse struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForgotPassword 请求密码重置
|
||||||
|
// POST /api/v1/auth/forgot-password
|
||||||
|
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||||
|
var req ForgotPasswordRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Turnstile 验证
|
||||||
|
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build frontend base URL from request
|
||||||
|
scheme := "https"
|
||||||
|
if c.Request.TLS == nil {
|
||||||
|
// Check X-Forwarded-Proto header (common in reverse proxy setups)
|
||||||
|
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
|
||||||
|
scheme = proto
|
||||||
|
} else {
|
||||||
|
scheme = "http"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
frontendBaseURL := scheme + "://" + c.Request.Host
|
||||||
|
|
||||||
|
// Request password reset (async)
|
||||||
|
// Note: This returns success even if email doesn't exist (to prevent enumeration)
|
||||||
|
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, ForgotPasswordResponse{
|
||||||
|
Message: "If your email is registered, you will receive a password reset link shortly.",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetPasswordRequest 重置密码请求
|
||||||
|
type ResetPasswordRequest struct {
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
Token string `json:"token" binding:"required"`
|
||||||
|
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetPasswordResponse 重置密码响应
|
||||||
|
type ResetPasswordResponse struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetPassword 重置密码
|
||||||
|
// POST /api/v1/auth/reset-password
|
||||||
|
func (h *AuthHandler) ResetPassword(c *gin.Context) {
|
||||||
|
var req ResetPasswordRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset password
|
||||||
|
if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, ResetPasswordResponse{
|
||||||
|
Message: "Your password has been reset successfully. You can now log in with your new password.",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ func UserFromServiceShallow(u *service.User) *User {
|
|||||||
ID: u.ID,
|
ID: u.ID,
|
||||||
Email: u.Email,
|
Email: u.Email,
|
||||||
Username: u.Username,
|
Username: u.Username,
|
||||||
Notes: u.Notes,
|
|
||||||
Role: u.Role,
|
Role: u.Role,
|
||||||
Balance: u.Balance,
|
Balance: u.Balance,
|
||||||
Concurrency: u.Concurrency,
|
Concurrency: u.Concurrency,
|
||||||
@@ -48,6 +47,22 @@ func UserFromService(u *service.User) *User {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserFromServiceAdmin converts a service User to DTO for admin users.
|
||||||
|
// It includes notes - user-facing endpoints must not use this.
|
||||||
|
func UserFromServiceAdmin(u *service.User) *AdminUser {
|
||||||
|
if u == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
base := UserFromService(u)
|
||||||
|
if base == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &AdminUser{
|
||||||
|
User: *base,
|
||||||
|
Notes: u.Notes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func APIKeyFromService(k *service.APIKey) *APIKey {
|
func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||||
if k == nil {
|
if k == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -72,36 +87,29 @@ func GroupFromServiceShallow(g *service.Group) *Group {
|
|||||||
if g == nil {
|
if g == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &Group{
|
out := groupFromServiceBase(g)
|
||||||
ID: g.ID,
|
return &out
|
||||||
Name: g.Name,
|
|
||||||
Description: g.Description,
|
|
||||||
Platform: g.Platform,
|
|
||||||
RateMultiplier: g.RateMultiplier,
|
|
||||||
IsExclusive: g.IsExclusive,
|
|
||||||
Status: g.Status,
|
|
||||||
SubscriptionType: g.SubscriptionType,
|
|
||||||
DailyLimitUSD: g.DailyLimitUSD,
|
|
||||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
|
||||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
|
||||||
ImagePrice1K: g.ImagePrice1K,
|
|
||||||
ImagePrice2K: g.ImagePrice2K,
|
|
||||||
ImagePrice4K: g.ImagePrice4K,
|
|
||||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
|
||||||
FallbackGroupID: g.FallbackGroupID,
|
|
||||||
ModelRouting: g.ModelRouting,
|
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
|
||||||
CreatedAt: g.CreatedAt,
|
|
||||||
UpdatedAt: g.UpdatedAt,
|
|
||||||
AccountCount: g.AccountCount,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GroupFromService(g *service.Group) *Group {
|
func GroupFromService(g *service.Group) *Group {
|
||||||
if g == nil {
|
if g == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
out := GroupFromServiceShallow(g)
|
return GroupFromServiceShallow(g)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupFromServiceAdmin converts a service Group to DTO for admin users.
|
||||||
|
// It includes internal fields like model_routing and account_count.
|
||||||
|
func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||||
|
if g == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := &AdminGroup{
|
||||||
|
Group: groupFromServiceBase(g),
|
||||||
|
ModelRouting: g.ModelRouting,
|
||||||
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
|
AccountCount: g.AccountCount,
|
||||||
|
}
|
||||||
if len(g.AccountGroups) > 0 {
|
if len(g.AccountGroups) > 0 {
|
||||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||||
for i := range g.AccountGroups {
|
for i := range g.AccountGroups {
|
||||||
@@ -112,6 +120,29 @@ func GroupFromService(g *service.Group) *Group {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func groupFromServiceBase(g *service.Group) Group {
|
||||||
|
return Group{
|
||||||
|
ID: g.ID,
|
||||||
|
Name: g.Name,
|
||||||
|
Description: g.Description,
|
||||||
|
Platform: g.Platform,
|
||||||
|
RateMultiplier: g.RateMultiplier,
|
||||||
|
IsExclusive: g.IsExclusive,
|
||||||
|
Status: g.Status,
|
||||||
|
SubscriptionType: g.SubscriptionType,
|
||||||
|
DailyLimitUSD: g.DailyLimitUSD,
|
||||||
|
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||||
|
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||||
|
ImagePrice1K: g.ImagePrice1K,
|
||||||
|
ImagePrice2K: g.ImagePrice2K,
|
||||||
|
ImagePrice4K: g.ImagePrice4K,
|
||||||
|
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||||
|
FallbackGroupID: g.FallbackGroupID,
|
||||||
|
CreatedAt: g.CreatedAt,
|
||||||
|
UpdatedAt: g.UpdatedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func AccountFromServiceShallow(a *service.Account) *Account {
|
func AccountFromServiceShallow(a *service.Account) *Account {
|
||||||
if a == nil {
|
if a == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -273,7 +304,24 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
|||||||
if rc == nil {
|
if rc == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &RedeemCode{
|
out := redeemCodeFromServiceBase(rc)
|
||||||
|
return &out
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedeemCodeFromServiceAdmin converts a service RedeemCode to DTO for admin users.
|
||||||
|
// It includes notes - user-facing endpoints must not use this.
|
||||||
|
func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
|
||||||
|
if rc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &AdminRedeemCode{
|
||||||
|
RedeemCode: redeemCodeFromServiceBase(rc),
|
||||||
|
Notes: rc.Notes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
|
||||||
|
return RedeemCode{
|
||||||
ID: rc.ID,
|
ID: rc.ID,
|
||||||
Code: rc.Code,
|
Code: rc.Code,
|
||||||
Type: rc.Type,
|
Type: rc.Type,
|
||||||
@@ -281,7 +329,6 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
|||||||
Status: rc.Status,
|
Status: rc.Status,
|
||||||
UsedBy: rc.UsedBy,
|
UsedBy: rc.UsedBy,
|
||||||
UsedAt: rc.UsedAt,
|
UsedAt: rc.UsedAt,
|
||||||
Notes: rc.Notes,
|
|
||||||
CreatedAt: rc.CreatedAt,
|
CreatedAt: rc.CreatedAt,
|
||||||
GroupID: rc.GroupID,
|
GroupID: rc.GroupID,
|
||||||
ValidityDays: rc.ValidityDays,
|
ValidityDays: rc.ValidityDays,
|
||||||
@@ -302,14 +349,9 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
|
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||||
// The account parameter allows caller to control what Account info is included.
|
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
|
||||||
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
|
return UsageLog{
|
||||||
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
|
|
||||||
if l == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
result := &UsageLog{
|
|
||||||
ID: l.ID,
|
ID: l.ID,
|
||||||
UserID: l.UserID,
|
UserID: l.UserID,
|
||||||
APIKeyID: l.APIKeyID,
|
APIKeyID: l.APIKeyID,
|
||||||
@@ -331,7 +373,6 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
|
|||||||
TotalCost: l.TotalCost,
|
TotalCost: l.TotalCost,
|
||||||
ActualCost: l.ActualCost,
|
ActualCost: l.ActualCost,
|
||||||
RateMultiplier: l.RateMultiplier,
|
RateMultiplier: l.RateMultiplier,
|
||||||
AccountRateMultiplier: l.AccountRateMultiplier,
|
|
||||||
BillingType: l.BillingType,
|
BillingType: l.BillingType,
|
||||||
Stream: l.Stream,
|
Stream: l.Stream,
|
||||||
DurationMs: l.DurationMs,
|
DurationMs: l.DurationMs,
|
||||||
@@ -342,30 +383,33 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
|
|||||||
CreatedAt: l.CreatedAt,
|
CreatedAt: l.CreatedAt,
|
||||||
User: UserFromServiceShallow(l.User),
|
User: UserFromServiceShallow(l.User),
|
||||||
APIKey: APIKeyFromService(l.APIKey),
|
APIKey: APIKeyFromService(l.APIKey),
|
||||||
Account: account,
|
|
||||||
Group: GroupFromServiceShallow(l.Group),
|
Group: GroupFromServiceShallow(l.Group),
|
||||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||||
}
|
}
|
||||||
// IP 地址仅对管理员可见
|
|
||||||
if includeIPAddress {
|
|
||||||
result.IPAddress = l.IPAddress
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsageLogFromService converts a service UsageLog to DTO for regular users.
|
// UsageLogFromService converts a service UsageLog to DTO for regular users.
|
||||||
// It excludes Account details and IP address - users should not see these.
|
// It excludes Account details and IP address - users should not see these.
|
||||||
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||||
return usageLogFromServiceBase(l, nil, false)
|
if l == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
u := usageLogFromServiceUser(l)
|
||||||
|
return &u
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
|
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
|
||||||
// It includes minimal Account info (ID, Name only) and IP address.
|
// It includes minimal Account info (ID, Name only) and IP address.
|
||||||
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
|
func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
|
||||||
if l == nil {
|
if l == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
|
return &AdminUsageLog{
|
||||||
|
UsageLog: usageLogFromServiceUser(l),
|
||||||
|
AccountRateMultiplier: l.AccountRateMultiplier,
|
||||||
|
IPAddress: l.IPAddress,
|
||||||
|
Account: AccountSummaryFromService(l.Account),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
|
func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
|
||||||
@@ -414,7 +458,27 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
|
|||||||
if sub == nil {
|
if sub == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &UserSubscription{
|
out := userSubscriptionFromServiceBase(sub)
|
||||||
|
return &out
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserSubscriptionFromServiceAdmin converts a service UserSubscription to DTO for admin users.
|
||||||
|
// It includes assignment metadata and notes.
|
||||||
|
func UserSubscriptionFromServiceAdmin(sub *service.UserSubscription) *AdminUserSubscription {
|
||||||
|
if sub == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &AdminUserSubscription{
|
||||||
|
UserSubscription: userSubscriptionFromServiceBase(sub),
|
||||||
|
AssignedBy: sub.AssignedBy,
|
||||||
|
AssignedAt: sub.AssignedAt,
|
||||||
|
Notes: sub.Notes,
|
||||||
|
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func userSubscriptionFromServiceBase(sub *service.UserSubscription) UserSubscription {
|
||||||
|
return UserSubscription{
|
||||||
ID: sub.ID,
|
ID: sub.ID,
|
||||||
UserID: sub.UserID,
|
UserID: sub.UserID,
|
||||||
GroupID: sub.GroupID,
|
GroupID: sub.GroupID,
|
||||||
@@ -427,14 +491,10 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
|
|||||||
DailyUsageUSD: sub.DailyUsageUSD,
|
DailyUsageUSD: sub.DailyUsageUSD,
|
||||||
WeeklyUsageUSD: sub.WeeklyUsageUSD,
|
WeeklyUsageUSD: sub.WeeklyUsageUSD,
|
||||||
MonthlyUsageUSD: sub.MonthlyUsageUSD,
|
MonthlyUsageUSD: sub.MonthlyUsageUSD,
|
||||||
AssignedBy: sub.AssignedBy,
|
|
||||||
AssignedAt: sub.AssignedAt,
|
|
||||||
Notes: sub.Notes,
|
|
||||||
CreatedAt: sub.CreatedAt,
|
CreatedAt: sub.CreatedAt,
|
||||||
UpdatedAt: sub.UpdatedAt,
|
UpdatedAt: sub.UpdatedAt,
|
||||||
User: UserFromServiceShallow(sub.User),
|
User: UserFromServiceShallow(sub.User),
|
||||||
Group: GroupFromServiceShallow(sub.Group),
|
Group: GroupFromServiceShallow(sub.Group),
|
||||||
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -442,9 +502,9 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
|
|||||||
if r == nil {
|
if r == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
subs := make([]UserSubscription, 0, len(r.Subscriptions))
|
subs := make([]AdminUserSubscription, 0, len(r.Subscriptions))
|
||||||
for i := range r.Subscriptions {
|
for i := range r.Subscriptions {
|
||||||
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
|
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
|
||||||
}
|
}
|
||||||
return &BulkAssignResult{
|
return &BulkAssignResult{
|
||||||
SuccessCount: r.SuccessCount,
|
SuccessCount: r.SuccessCount,
|
||||||
|
|||||||
@@ -2,8 +2,12 @@ package dto
|
|||||||
|
|
||||||
// SystemSettings represents the admin settings API response payload.
|
// SystemSettings represents the admin settings API response payload.
|
||||||
type SystemSettings struct {
|
type SystemSettings struct {
|
||||||
RegistrationEnabled bool `json:"registration_enabled"`
|
RegistrationEnabled bool `json:"registration_enabled"`
|
||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
|
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||||
|
|
||||||
SMTPHost string `json:"smtp_host"`
|
SMTPHost string `json:"smtp_host"`
|
||||||
SMTPPort int `json:"smtp_port"`
|
SMTPPort int `json:"smtp_port"`
|
||||||
@@ -22,13 +26,16 @@ type SystemSettings struct {
|
|||||||
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
||||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||||
|
|
||||||
SiteName string `json:"site_name"`
|
SiteName string `json:"site_name"`
|
||||||
SiteLogo string `json:"site_logo"`
|
SiteLogo string `json:"site_logo"`
|
||||||
SiteSubtitle string `json:"site_subtitle"`
|
SiteSubtitle string `json:"site_subtitle"`
|
||||||
APIBaseURL string `json:"api_base_url"`
|
APIBaseURL string `json:"api_base_url"`
|
||||||
ContactInfo string `json:"contact_info"`
|
ContactInfo string `json:"contact_info"`
|
||||||
DocURL string `json:"doc_url"`
|
DocURL string `json:"doc_url"`
|
||||||
HomeContent string `json:"home_content"`
|
HomeContent string `json:"home_content"`
|
||||||
|
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||||
|
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||||
|
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||||
|
|
||||||
DefaultConcurrency int `json:"default_concurrency"`
|
DefaultConcurrency int `json:"default_concurrency"`
|
||||||
DefaultBalance float64 `json:"default_balance"`
|
DefaultBalance float64 `json:"default_balance"`
|
||||||
@@ -52,19 +59,25 @@ type SystemSettings struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PublicSettings struct {
|
type PublicSettings struct {
|
||||||
RegistrationEnabled bool `json:"registration_enabled"`
|
RegistrationEnabled bool `json:"registration_enabled"`
|
||||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
SiteName string `json:"site_name"`
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
SiteLogo string `json:"site_logo"`
|
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||||
SiteSubtitle string `json:"site_subtitle"`
|
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||||
APIBaseURL string `json:"api_base_url"`
|
SiteName string `json:"site_name"`
|
||||||
ContactInfo string `json:"contact_info"`
|
SiteLogo string `json:"site_logo"`
|
||||||
DocURL string `json:"doc_url"`
|
SiteSubtitle string `json:"site_subtitle"`
|
||||||
HomeContent string `json:"home_content"`
|
APIBaseURL string `json:"api_base_url"`
|
||||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
ContactInfo string `json:"contact_info"`
|
||||||
Version string `json:"version"`
|
DocURL string `json:"doc_url"`
|
||||||
|
HomeContent string `json:"home_content"`
|
||||||
|
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||||
|
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||||
|
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||||
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ type User struct {
|
|||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Notes string `json:"notes"`
|
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Balance float64 `json:"balance"`
|
Balance float64 `json:"balance"`
|
||||||
Concurrency int `json:"concurrency"`
|
Concurrency int `json:"concurrency"`
|
||||||
@@ -19,6 +18,14 @@ type User struct {
|
|||||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AdminUser 是管理员接口使用的 user DTO(包含敏感/内部字段)。
|
||||||
|
// 注意:普通用户接口不得返回 notes 等管理员备注信息。
|
||||||
|
type AdminUser struct {
|
||||||
|
User
|
||||||
|
|
||||||
|
Notes string `json:"notes"`
|
||||||
|
}
|
||||||
|
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
@@ -58,13 +65,19 @@ type Group struct {
|
|||||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
|
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminGroup 是管理员接口使用的 group DTO(包含敏感/内部字段)。
|
||||||
|
// 注意:普通用户接口不得返回 model_routing/account_count/account_groups 等内部信息。
|
||||||
|
type AdminGroup struct {
|
||||||
|
Group
|
||||||
|
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||||
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
|
||||||
|
|
||||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||||
AccountCount int64 `json:"account_count,omitempty"`
|
AccountCount int64 `json:"account_count,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -180,7 +193,6 @@ type RedeemCode struct {
|
|||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
UsedBy *int64 `json:"used_by"`
|
UsedBy *int64 `json:"used_by"`
|
||||||
UsedAt *time.Time `json:"used_at"`
|
UsedAt *time.Time `json:"used_at"`
|
||||||
Notes string `json:"notes"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
|
||||||
GroupID *int64 `json:"group_id"`
|
GroupID *int64 `json:"group_id"`
|
||||||
@@ -190,6 +202,15 @@ type RedeemCode struct {
|
|||||||
Group *Group `json:"group,omitempty"`
|
Group *Group `json:"group,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AdminRedeemCode 是管理员接口使用的 redeem code DTO(包含 notes 等字段)。
|
||||||
|
// 注意:普通用户接口不得返回 notes 等内部信息。
|
||||||
|
type AdminRedeemCode struct {
|
||||||
|
RedeemCode
|
||||||
|
|
||||||
|
Notes string `json:"notes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsageLog 是普通用户接口使用的 usage log DTO(不包含管理员字段)。
|
||||||
type UsageLog struct {
|
type UsageLog struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
@@ -209,14 +230,13 @@ type UsageLog struct {
|
|||||||
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
|
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
|
||||||
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
|
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
|
||||||
|
|
||||||
InputCost float64 `json:"input_cost"`
|
InputCost float64 `json:"input_cost"`
|
||||||
OutputCost float64 `json:"output_cost"`
|
OutputCost float64 `json:"output_cost"`
|
||||||
CacheCreationCost float64 `json:"cache_creation_cost"`
|
CacheCreationCost float64 `json:"cache_creation_cost"`
|
||||||
CacheReadCost float64 `json:"cache_read_cost"`
|
CacheReadCost float64 `json:"cache_read_cost"`
|
||||||
TotalCost float64 `json:"total_cost"`
|
TotalCost float64 `json:"total_cost"`
|
||||||
ActualCost float64 `json:"actual_cost"`
|
ActualCost float64 `json:"actual_cost"`
|
||||||
RateMultiplier float64 `json:"rate_multiplier"`
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
|
||||||
|
|
||||||
BillingType int8 `json:"billing_type"`
|
BillingType int8 `json:"billing_type"`
|
||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
@@ -230,18 +250,28 @@ type UsageLog struct {
|
|||||||
// User-Agent
|
// User-Agent
|
||||||
UserAgent *string `json:"user_agent"`
|
UserAgent *string `json:"user_agent"`
|
||||||
|
|
||||||
// IP 地址(仅管理员可见)
|
|
||||||
IPAddress *string `json:"ip_address,omitempty"`
|
|
||||||
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
|
||||||
User *User `json:"user,omitempty"`
|
User *User `json:"user,omitempty"`
|
||||||
APIKey *APIKey `json:"api_key,omitempty"`
|
APIKey *APIKey `json:"api_key,omitempty"`
|
||||||
Account *AccountSummary `json:"account,omitempty"` // Use minimal AccountSummary to prevent data leakage
|
|
||||||
Group *Group `json:"group,omitempty"`
|
Group *Group `json:"group,omitempty"`
|
||||||
Subscription *UserSubscription `json:"subscription,omitempty"`
|
Subscription *UserSubscription `json:"subscription,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AdminUsageLog 是管理员接口使用的 usage log DTO(包含管理员字段)。
|
||||||
|
type AdminUsageLog struct {
|
||||||
|
UsageLog
|
||||||
|
|
||||||
|
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
|
||||||
|
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
||||||
|
|
||||||
|
// IPAddress 用户请求 IP(仅管理员可见)
|
||||||
|
IPAddress *string `json:"ip_address,omitempty"`
|
||||||
|
|
||||||
|
// Account 最小账号信息(避免泄露敏感字段)
|
||||||
|
Account *AccountSummary `json:"account,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type UsageCleanupFilters struct {
|
type UsageCleanupFilters struct {
|
||||||
StartTime time.Time `json:"start_time"`
|
StartTime time.Time `json:"start_time"`
|
||||||
EndTime time.Time `json:"end_time"`
|
EndTime time.Time `json:"end_time"`
|
||||||
@@ -300,23 +330,30 @@ type UserSubscription struct {
|
|||||||
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
|
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
|
||||||
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
|
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
|
||||||
|
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
|
||||||
|
User *User `json:"user,omitempty"`
|
||||||
|
Group *Group `json:"group,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminUserSubscription 是管理员接口使用的订阅 DTO(包含分配信息/备注等字段)。
|
||||||
|
// 注意:普通用户接口不得返回 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段。
|
||||||
|
type AdminUserSubscription struct {
|
||||||
|
UserSubscription
|
||||||
|
|
||||||
AssignedBy *int64 `json:"assigned_by"`
|
AssignedBy *int64 `json:"assigned_by"`
|
||||||
AssignedAt time.Time `json:"assigned_at"`
|
AssignedAt time.Time `json:"assigned_at"`
|
||||||
Notes string `json:"notes"`
|
Notes string `json:"notes"`
|
||||||
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
AssignedByUser *User `json:"assigned_by_user,omitempty"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
|
||||||
|
|
||||||
User *User `json:"user,omitempty"`
|
|
||||||
Group *Group `json:"group,omitempty"`
|
|
||||||
AssignedByUser *User `json:"assigned_by_user,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type BulkAssignResult struct {
|
type BulkAssignResult struct {
|
||||||
SuccessCount int `json:"success_count"`
|
SuccessCount int `json:"success_count"`
|
||||||
FailedCount int `json:"failed_count"`
|
FailedCount int `json:"failed_count"`
|
||||||
Subscriptions []UserSubscription `json:"subscriptions"`
|
Subscriptions []AdminUserSubscription `json:"subscriptions"`
|
||||||
Errors []string `json:"errors"`
|
Errors []string `json:"errors"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PromoCode 注册优惠码
|
// PromoCode 注册优惠码
|
||||||
|
|||||||
@@ -209,17 +209,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
account := selection.Account
|
account := selection.Account
|
||||||
setOpsSelectedAccount(c, account.ID)
|
setOpsSelectedAccount(c, account.ID)
|
||||||
|
|
||||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
if account.IsInterceptWarmupEnabled() {
|
||||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
interceptType := detectInterceptType(body)
|
||||||
selection.ReleaseFunc()
|
if interceptType != InterceptTypeNone {
|
||||||
|
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
|
if reqStream {
|
||||||
|
sendMockInterceptStream(c, reqModel, interceptType)
|
||||||
|
} else {
|
||||||
|
sendMockInterceptResponse(c, reqModel, interceptType)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if reqStream {
|
|
||||||
sendMockWarmupStream(c, reqModel)
|
|
||||||
} else {
|
|
||||||
sendMockWarmupResponse(c, reqModel)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 获取账号并发槽位
|
// 3. 获取账号并发槽位
|
||||||
@@ -344,17 +347,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
account := selection.Account
|
account := selection.Account
|
||||||
setOpsSelectedAccount(c, account.ID)
|
setOpsSelectedAccount(c, account.ID)
|
||||||
|
|
||||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
if account.IsInterceptWarmupEnabled() {
|
||||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
interceptType := detectInterceptType(body)
|
||||||
selection.ReleaseFunc()
|
if interceptType != InterceptTypeNone {
|
||||||
|
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
|
if reqStream {
|
||||||
|
sendMockInterceptStream(c, reqModel, interceptType)
|
||||||
|
} else {
|
||||||
|
sendMockInterceptResponse(c, reqModel, interceptType)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if reqStream {
|
|
||||||
sendMockWarmupStream(c, reqModel)
|
|
||||||
} else {
|
|
||||||
sendMockWarmupResponse(c, reqModel)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 获取账号并发槽位
|
// 3. 获取账号并发槽位
|
||||||
@@ -765,17 +771,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isWarmupRequest 检测是否为预热请求(标题生成、Warmup等)
|
// InterceptType 表示请求拦截类型
|
||||||
func isWarmupRequest(body []byte) bool {
|
type InterceptType int
|
||||||
// 快速检查:如果body不包含关键字,直接返回false
|
|
||||||
|
const (
|
||||||
|
InterceptTypeNone InterceptType = iota
|
||||||
|
InterceptTypeWarmup // 预热请求(返回 "New Conversation")
|
||||||
|
InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
|
||||||
|
)
|
||||||
|
|
||||||
|
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
|
||||||
|
func detectInterceptType(body []byte) InterceptType {
|
||||||
|
// 快速检查:如果不包含任何关键字,直接返回
|
||||||
bodyStr := string(body)
|
bodyStr := string(body)
|
||||||
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
|
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
|
||||||
return false
|
hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup")
|
||||||
|
|
||||||
|
if !hasSuggestionMode && !hasWarmupKeyword {
|
||||||
|
return InterceptTypeNone
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析完整请求
|
// 解析请求(只解析一次)
|
||||||
var req struct {
|
var req struct {
|
||||||
Messages []struct {
|
Messages []struct {
|
||||||
|
Role string `json:"role"`
|
||||||
Content []struct {
|
Content []struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
@@ -786,43 +805,71 @@ func isWarmupRequest(body []byte) bool {
|
|||||||
} `json:"system"`
|
} `json:"system"`
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
return false
|
return InterceptTypeNone
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查 messages 中的标题提示模式
|
// 检查 SUGGESTION MODE(最后一条 user 消息)
|
||||||
for _, msg := range req.Messages {
|
if hasSuggestionMode && len(req.Messages) > 0 {
|
||||||
for _, content := range msg.Content {
|
lastMsg := req.Messages[len(req.Messages)-1]
|
||||||
if content.Type == "text" {
|
if lastMsg.Role == "user" && len(lastMsg.Content) > 0 &&
|
||||||
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
|
lastMsg.Content[0].Type == "text" &&
|
||||||
content.Text == "Warmup" {
|
strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") {
|
||||||
return true
|
return InterceptTypeSuggestionMode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 Warmup 请求
|
||||||
|
if hasWarmupKeyword {
|
||||||
|
// 检查 messages 中的标题提示模式
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
for _, content := range msg.Content {
|
||||||
|
if content.Type == "text" {
|
||||||
|
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
|
||||||
|
content.Text == "Warmup" {
|
||||||
|
return InterceptTypeWarmup
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// 检查 system 中的标题提取模式
|
||||||
|
for _, sys := range req.System {
|
||||||
|
if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
|
||||||
|
return InterceptTypeWarmup
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查 system 中的标题提取模式
|
return InterceptTypeNone
|
||||||
for _, system := range req.System {
|
|
||||||
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
|
// sendMockInterceptStream 发送流式 mock 响应(用于请求拦截)
|
||||||
func sendMockWarmupStream(c *gin.Context, model string) {
|
func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
c.Header("Cache-Control", "no-cache")
|
c.Header("Cache-Control", "no-cache")
|
||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
c.Header("X-Accel-Buffering", "no")
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
|
// 根据拦截类型决定响应内容
|
||||||
|
var msgID string
|
||||||
|
var outputTokens int
|
||||||
|
var textDeltas []string
|
||||||
|
|
||||||
|
switch interceptType {
|
||||||
|
case InterceptTypeSuggestionMode:
|
||||||
|
msgID = "msg_mock_suggestion"
|
||||||
|
outputTokens = 1
|
||||||
|
textDeltas = []string{""} // 空内容
|
||||||
|
default: // InterceptTypeWarmup
|
||||||
|
msgID = "msg_mock_warmup"
|
||||||
|
outputTokens = 2
|
||||||
|
textDeltas = []string{"New", " Conversation"}
|
||||||
|
}
|
||||||
|
|
||||||
// Build message_start event with proper JSON marshaling
|
// Build message_start event with proper JSON marshaling
|
||||||
messageStart := map[string]any{
|
messageStart := map[string]any{
|
||||||
"type": "message_start",
|
"type": "message_start",
|
||||||
"message": map[string]any{
|
"message": map[string]any{
|
||||||
"id": "msg_mock_warmup",
|
"id": msgID,
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"model": model,
|
"model": model,
|
||||||
@@ -837,16 +884,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
|||||||
}
|
}
|
||||||
messageStartJSON, _ := json.Marshal(messageStart)
|
messageStartJSON, _ := json.Marshal(messageStart)
|
||||||
|
|
||||||
|
// Build events
|
||||||
events := []string{
|
events := []string{
|
||||||
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
|
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
|
||||||
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
|
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
|
||||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
|
||||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
|
||||||
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
|
|
||||||
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
|
|
||||||
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add text deltas
|
||||||
|
for _, text := range textDeltas {
|
||||||
|
delta := map[string]any{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]string{
|
||||||
|
"type": "text_delta",
|
||||||
|
"text": text,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
deltaJSON, _ := json.Marshal(delta)
|
||||||
|
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add final events
|
||||||
|
messageDelta := map[string]any{
|
||||||
|
"type": "message_delta",
|
||||||
|
"delta": map[string]any{
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"stop_sequence": nil,
|
||||||
|
},
|
||||||
|
"usage": map[string]int{
|
||||||
|
"input_tokens": 10,
|
||||||
|
"output_tokens": outputTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
messageDeltaJSON, _ := json.Marshal(messageDelta)
|
||||||
|
|
||||||
|
events = append(events,
|
||||||
|
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
|
||||||
|
`event: message_delta`+"\n"+`data: `+string(messageDeltaJSON),
|
||||||
|
`event: message_stop`+"\n"+`data: {"type":"message_stop"}`,
|
||||||
|
)
|
||||||
|
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
_, _ = c.Writer.WriteString(event + "\n\n")
|
_, _ = c.Writer.WriteString(event + "\n\n")
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
@@ -854,18 +931,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
|
||||||
func sendMockWarmupResponse(c *gin.Context, model string) {
|
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
|
||||||
|
var msgID, text string
|
||||||
|
var outputTokens int
|
||||||
|
|
||||||
|
switch interceptType {
|
||||||
|
case InterceptTypeSuggestionMode:
|
||||||
|
msgID = "msg_mock_suggestion"
|
||||||
|
text = ""
|
||||||
|
outputTokens = 1
|
||||||
|
default: // InterceptTypeWarmup
|
||||||
|
msgID = "msg_mock_warmup"
|
||||||
|
text = "New Conversation"
|
||||||
|
outputTokens = 2
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"id": "msg_mock_warmup",
|
"id": msgID,
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"model": model,
|
"model": model,
|
||||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
"content": []gin.H{{"type": "text", "text": text}},
|
||||||
"stop_reason": "end_turn",
|
"stop_reason": "end_turn",
|
||||||
"usage": gin.H{
|
"usage": gin.H{
|
||||||
"input_tokens": 10,
|
"input_tokens": 10,
|
||||||
"output_tokens": 2,
|
"output_tokens": outputTokens,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
122
backend/internal/handler/gemini_cli_session_test.go
Normal file
122
backend/internal/handler/gemini_cli_session_test.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractGeminiCLISessionHash(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
privilegedUserID string
|
||||||
|
wantEmpty bool
|
||||||
|
wantHash string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with privileged-user-id and tmp dir",
|
||||||
|
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
|
||||||
|
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||||
|
wantEmpty: false,
|
||||||
|
wantHash: func() string {
|
||||||
|
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
|
||||||
|
hash := sha256.Sum256([]byte(combined))
|
||||||
|
return hex.EncodeToString(hash[:])
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "without privileged-user-id but with tmp dir",
|
||||||
|
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
|
||||||
|
privilegedUserID: "",
|
||||||
|
wantEmpty: false,
|
||||||
|
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "without tmp dir",
|
||||||
|
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
|
||||||
|
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||||
|
wantEmpty: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty body",
|
||||||
|
body: "",
|
||||||
|
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||||
|
wantEmpty: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// 创建测试上下文
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("POST", "/test", nil)
|
||||||
|
if tt.privilegedUserID != "" {
|
||||||
|
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用函数
|
||||||
|
result := extractGeminiCLISessionHash(c, []byte(tt.body))
|
||||||
|
|
||||||
|
// 验证结果
|
||||||
|
if tt.wantEmpty {
|
||||||
|
require.Empty(t, result, "expected empty session hash")
|
||||||
|
} else {
|
||||||
|
require.NotEmpty(t, result, "expected non-empty session hash")
|
||||||
|
require.Equal(t, tt.wantHash, result, "session hash mismatch")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiCLITmpDirRegex(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantMatch bool
|
||||||
|
wantHash string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid tmp dir path",
|
||||||
|
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||||
|
wantMatch: true,
|
||||||
|
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid tmp dir path in text",
|
||||||
|
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
|
||||||
|
wantMatch: true,
|
||||||
|
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid hash length",
|
||||||
|
input: "/Users/ianshaw/.gemini/tmp/abc123",
|
||||||
|
wantMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no tmp dir",
|
||||||
|
input: "Hello world",
|
||||||
|
wantMatch: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
|
||||||
|
if tt.wantMatch {
|
||||||
|
require.NotNil(t, match, "expected regex to match")
|
||||||
|
require.Len(t, match, 2, "expected 2 capture groups")
|
||||||
|
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
|
||||||
|
} else {
|
||||||
|
require.Nil(t, match, "expected regex not to match")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +1,15 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,6 +23,17 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
|
||||||
|
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
|
||||||
|
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
|
||||||
|
|
||||||
|
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
|
||||||
|
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return geminiCLITmpDirRegex.Match(body)
|
||||||
|
}
|
||||||
|
|
||||||
// GeminiV1BetaListModels proxies:
|
// GeminiV1BetaListModels proxies:
|
||||||
// GET /v1beta/models
|
// GET /v1beta/models
|
||||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||||
@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 3) select account (sticky session based on request body)
|
// 3) select account (sticky session based on request body)
|
||||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
// 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希)
|
||||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
sessionHash := extractGeminiCLISessionHash(c, body)
|
||||||
|
if sessionHash == "" {
|
||||||
|
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
|
||||||
|
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||||
|
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
|
||||||
|
}
|
||||||
sessionKey := sessionHash
|
sessionKey := sessionHash
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
sessionKey = "gemini:" + sessionHash
|
sessionKey = "gemini:" + sessionHash
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 查询粘性会话绑定的账号 ID(用于检测账号切换)
|
||||||
|
var sessionBoundAccountID int64
|
||||||
|
if sessionKey != "" {
|
||||||
|
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||||
|
}
|
||||||
|
isCLI := isGeminiCLIRequest(c, body)
|
||||||
|
cleanedForUnknownBinding := false
|
||||||
|
|
||||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
account := selection.Account
|
account := selection.Account
|
||||||
setOpsSelectedAccount(c, account.ID)
|
setOpsSelectedAccount(c, account.ID)
|
||||||
|
|
||||||
|
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
|
||||||
|
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
|
||||||
|
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
|
||||||
|
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
||||||
|
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||||
|
sessionBoundAccountID = account.ID
|
||||||
|
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||||
|
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
|
||||||
|
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
||||||
|
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
|
||||||
|
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||||
|
cleanedForUnknownBinding = true
|
||||||
|
sessionBoundAccountID = account.ID
|
||||||
|
} else if sessionBoundAccountID == 0 {
|
||||||
|
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
|
||||||
|
sessionBoundAccountID = account.ID
|
||||||
|
}
|
||||||
|
|
||||||
// 4) account concurrency slot
|
// 4) account concurrency slot
|
||||||
accountReleaseFunc := selection.ReleaseFunc
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
if !selection.Acquired {
|
if !selection.Acquired {
|
||||||
@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
|
||||||
|
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
|
||||||
|
//
|
||||||
|
// 会话标识生成策略:
|
||||||
|
// 1. 从请求体中提取 tmp 目录哈希(64位十六进制)
|
||||||
|
// 2. 从 header 中提取 privileged-user-id(UUID)
|
||||||
|
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
|
||||||
|
//
|
||||||
|
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
|
||||||
|
//
|
||||||
|
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
|
||||||
|
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
|
||||||
|
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
|
||||||
|
// 1. 从请求体中提取 tmp 目录哈希
|
||||||
|
match := geminiCLITmpDirRegex.FindSubmatch(body)
|
||||||
|
if len(match) < 2 {
|
||||||
|
return "" // 没有找到 tmp 目录,不使用粘性会话
|
||||||
|
}
|
||||||
|
tmpDirHash := string(match[1])
|
||||||
|
|
||||||
|
// 2. 提取 privileged-user-id
|
||||||
|
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
|
||||||
|
|
||||||
|
// 3. 组合生成最终的 session hash
|
||||||
|
if privilegedUserID != "" {
|
||||||
|
// 组合两个标识符:privileged-user-id + tmp 目录哈希
|
||||||
|
combined := privilegedUserID + ":" + tmpDirHash
|
||||||
|
hash := sha256.Sum256([]byte(combined))
|
||||||
|
return hex.EncodeToString(hash[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有 privileged-user-id,直接使用 tmp 目录哈希
|
||||||
|
return tmpDirHash
|
||||||
|
}
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ type Handlers struct {
|
|||||||
Gateway *GatewayHandler
|
Gateway *GatewayHandler
|
||||||
OpenAIGateway *OpenAIGatewayHandler
|
OpenAIGateway *OpenAIGatewayHandler
|
||||||
Setting *SettingHandler
|
Setting *SettingHandler
|
||||||
|
Totp *TotpHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildInfo contains build-time information
|
// BuildInfo contains build-time information
|
||||||
|
|||||||
@@ -32,18 +32,24 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, dto.PublicSettings{
|
response.Success(c, dto.PublicSettings{
|
||||||
RegistrationEnabled: settings.RegistrationEnabled,
|
RegistrationEnabled: settings.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||||
TurnstileEnabled: settings.TurnstileEnabled,
|
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||||
SiteName: settings.SiteName,
|
TotpEnabled: settings.TotpEnabled,
|
||||||
SiteLogo: settings.SiteLogo,
|
TurnstileEnabled: settings.TurnstileEnabled,
|
||||||
SiteSubtitle: settings.SiteSubtitle,
|
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||||
APIBaseURL: settings.APIBaseURL,
|
SiteName: settings.SiteName,
|
||||||
ContactInfo: settings.ContactInfo,
|
SiteLogo: settings.SiteLogo,
|
||||||
DocURL: settings.DocURL,
|
SiteSubtitle: settings.SiteSubtitle,
|
||||||
HomeContent: settings.HomeContent,
|
APIBaseURL: settings.APIBaseURL,
|
||||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
ContactInfo: settings.ContactInfo,
|
||||||
Version: h.version,
|
DocURL: settings.DocURL,
|
||||||
|
HomeContent: settings.HomeContent,
|
||||||
|
HideCcsImportButton: settings.HideCcsImportButton,
|
||||||
|
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||||
|
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||||
|
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||||
|
Version: h.version,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
181
backend/internal/handler/totp_handler.go
Normal file
181
backend/internal/handler/totp_handler.go
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TotpHandler handles TOTP-related requests
|
||||||
|
type TotpHandler struct {
|
||||||
|
totpService *service.TotpService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTotpHandler creates a new TotpHandler
|
||||||
|
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
|
||||||
|
return &TotpHandler{
|
||||||
|
totpService: totpService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpStatusResponse represents the TOTP status response
|
||||||
|
type TotpStatusResponse struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
|
||||||
|
FeatureEnabled bool `json:"feature_enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStatus returns the TOTP status for the current user
|
||||||
|
// GET /api/v1/user/totp/status
|
||||||
|
func (h *TotpHandler) GetStatus(c *gin.Context) {
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "User not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := TotpStatusResponse{
|
||||||
|
Enabled: status.Enabled,
|
||||||
|
FeatureEnabled: status.FeatureEnabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.EnabledAt != nil {
|
||||||
|
ts := status.EnabledAt.Unix()
|
||||||
|
resp.EnabledAt = &ts
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSetupRequest represents the request to initiate TOTP setup
|
||||||
|
type TotpSetupRequest struct {
|
||||||
|
EmailCode string `json:"email_code"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpSetupResponse represents the TOTP setup response
|
||||||
|
type TotpSetupResponse struct {
|
||||||
|
Secret string `json:"secret"`
|
||||||
|
QRCodeURL string `json:"qr_code_url"`
|
||||||
|
SetupToken string `json:"setup_token"`
|
||||||
|
Countdown int `json:"countdown"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitiateSetup starts the TOTP setup process
|
||||||
|
// POST /api/v1/user/totp/setup
|
||||||
|
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "User not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req TotpSetupRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
// Allow empty body (optional params)
|
||||||
|
req = TotpSetupRequest{}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, TotpSetupResponse{
|
||||||
|
Secret: result.Secret,
|
||||||
|
QRCodeURL: result.QRCodeURL,
|
||||||
|
SetupToken: result.SetupToken,
|
||||||
|
Countdown: result.Countdown,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpEnableRequest represents the request to enable TOTP
|
||||||
|
type TotpEnableRequest struct {
|
||||||
|
TotpCode string `json:"totp_code" binding:"required,len=6"`
|
||||||
|
SetupToken string `json:"setup_token" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable completes the TOTP setup
|
||||||
|
// POST /api/v1/user/totp/enable
|
||||||
|
func (h *TotpHandler) Enable(c *gin.Context) {
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "User not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req TotpEnableRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotpDisableRequest represents the request to disable TOTP
|
||||||
|
type TotpDisableRequest struct {
|
||||||
|
EmailCode string `json:"email_code"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable disables TOTP for the current user
|
||||||
|
// POST /api/v1/user/totp/disable
|
||||||
|
func (h *TotpHandler) Disable(c *gin.Context) {
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "User not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req TotpDisableRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVerificationMethod returns the verification method for TOTP operations
|
||||||
|
// GET /api/v1/user/totp/verification-method
|
||||||
|
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
|
||||||
|
method := h.totpService.GetVerificationMethod(c.Request.Context())
|
||||||
|
response.Success(c, method)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendVerifyCode sends an email verification code for TOTP operations
|
||||||
|
// POST /api/v1/user/totp/send-code
|
||||||
|
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "User not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
@@ -47,9 +47,6 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清空notes字段,普通用户不应看到备注
|
|
||||||
userData.Notes = ""
|
|
||||||
|
|
||||||
response.Success(c, dto.UserFromService(userData))
|
response.Success(c, dto.UserFromService(userData))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,8 +102,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清空notes字段,普通用户不应看到备注
|
|
||||||
updatedUser.Notes = ""
|
|
||||||
|
|
||||||
response.Success(c, dto.UserFromService(updatedUser))
|
response.Success(c, dto.UserFromService(updatedUser))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ func ProvideHandlers(
|
|||||||
gatewayHandler *GatewayHandler,
|
gatewayHandler *GatewayHandler,
|
||||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||||
settingHandler *SettingHandler,
|
settingHandler *SettingHandler,
|
||||||
|
totpHandler *TotpHandler,
|
||||||
) *Handlers {
|
) *Handlers {
|
||||||
return &Handlers{
|
return &Handlers{
|
||||||
Auth: authHandler,
|
Auth: authHandler,
|
||||||
@@ -82,6 +83,7 @@ func ProvideHandlers(
|
|||||||
Gateway: gatewayHandler,
|
Gateway: gatewayHandler,
|
||||||
OpenAIGateway: openaiGatewayHandler,
|
OpenAIGateway: openaiGatewayHandler,
|
||||||
Setting: settingHandler,
|
Setting: settingHandler,
|
||||||
|
Totp: totpHandler,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewSubscriptionHandler,
|
NewSubscriptionHandler,
|
||||||
NewGatewayHandler,
|
NewGatewayHandler,
|
||||||
NewOpenAIGatewayHandler,
|
NewOpenAIGatewayHandler,
|
||||||
|
NewTotpHandler,
|
||||||
ProvideSettingHandler,
|
ProvideSettingHandler,
|
||||||
|
|
||||||
// Admin handlers
|
// Admin handlers
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ const (
|
|||||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||||
|
|
||||||
// User-Agent(与 Antigravity-Manager 保持一致)
|
// User-Agent(与 Antigravity-Manager 保持一致)
|
||||||
UserAgent = "antigravity/1.11.9 windows/amd64"
|
UserAgent = "antigravity/1.15.8 windows/amd64"
|
||||||
|
|
||||||
// Session 过期时间
|
// Session 过期时间
|
||||||
SessionTTL = 30 * time.Minute
|
SessionTTL = 30 * time.Minute
|
||||||
|
|||||||
@@ -7,13 +7,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"os"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -369,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
|||||||
Text: block.Thinking,
|
Text: block.Thinking,
|
||||||
Thought: true,
|
Thought: true,
|
||||||
}
|
}
|
||||||
// 保留原有 signature(Claude 模型需要有效的 signature)
|
// signature 处理:
|
||||||
if block.Signature != "" {
|
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
||||||
|
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
||||||
|
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
|
||||||
part.ThoughtSignature = block.Signature
|
part.ThoughtSignature = block.Signature
|
||||||
} else if !allowDummyThought {
|
} else if !allowDummyThought {
|
||||||
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
|
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
|
||||||
@@ -409,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
// tool_use 的 signature 处理:
|
// tool_use 的 signature 处理:
|
||||||
// - Gemini 模型:使用 dummy signature(跳过 thought_signature 校验)
|
// - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
|
||||||
// - Claude 模型:透传上游返回的真实 signature(Vertex/Google 需要完整签名链路)
|
// - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
|
||||||
if allowDummyThought {
|
if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
|
||||||
part.ThoughtSignature = dummyThoughtSignature
|
|
||||||
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
|
|
||||||
part.ThoughtSignature = block.Signature
|
part.ThoughtSignature = block.Signature
|
||||||
|
} else if allowDummyThought {
|
||||||
|
part.ThoughtSignature = dummyThoughtSignature
|
||||||
}
|
}
|
||||||
parts = append(parts, part)
|
parts = append(parts, part)
|
||||||
|
|
||||||
@@ -594,11 +594,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 清理 JSON Schema
|
// 清理 JSON Schema
|
||||||
params := cleanJSONSchema(inputSchema)
|
// 1. 深度清理 [undefined] 值
|
||||||
|
DeepCleanUndefined(inputSchema)
|
||||||
|
// 2. 转换为符合 Gemini v1internal 的 schema
|
||||||
|
params := CleanJSONSchema(inputSchema)
|
||||||
// 为 nil schema 提供默认值
|
// 为 nil schema 提供默认值
|
||||||
if params == nil {
|
if params == nil {
|
||||||
params = map[string]any{
|
params = map[string]any{
|
||||||
"type": "OBJECT",
|
"type": "object", // lowercase type
|
||||||
"properties": map[string]any{},
|
"properties": map[string]any{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -631,236 +634,3 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
|||||||
FunctionDeclarations: funcDecls,
|
FunctionDeclarations: funcDecls,
|
||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
|
|
||||||
// 参考 proxycast 的实现,确保 schema 符合 JSON Schema draft 2020-12
|
|
||||||
func cleanJSONSchema(schema map[string]any) map[string]any {
|
|
||||||
if schema == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
cleaned := cleanSchemaValue(schema, "$")
|
|
||||||
result, ok := cleaned.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 确保有 type 字段(默认 OBJECT)
|
|
||||||
if _, hasType := result["type"]; !hasType {
|
|
||||||
result["type"] = "OBJECT"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 确保有 properties 字段(默认空对象)
|
|
||||||
if _, hasProps := result["properties"]; !hasProps {
|
|
||||||
result["properties"] = make(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证 required 中的字段都存在于 properties 中
|
|
||||||
if required, ok := result["required"].([]any); ok {
|
|
||||||
if props, ok := result["properties"].(map[string]any); ok {
|
|
||||||
validRequired := make([]any, 0, len(required))
|
|
||||||
for _, r := range required {
|
|
||||||
if reqName, ok := r.(string); ok {
|
|
||||||
if _, exists := props[reqName]; exists {
|
|
||||||
validRequired = append(validRequired, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(validRequired) > 0 {
|
|
||||||
result["required"] = validRequired
|
|
||||||
} else {
|
|
||||||
delete(result, "required")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
var schemaValidationKeys = map[string]bool{
|
|
||||||
"minLength": true,
|
|
||||||
"maxLength": true,
|
|
||||||
"pattern": true,
|
|
||||||
"minimum": true,
|
|
||||||
"maximum": true,
|
|
||||||
"exclusiveMinimum": true,
|
|
||||||
"exclusiveMaximum": true,
|
|
||||||
"multipleOf": true,
|
|
||||||
"uniqueItems": true,
|
|
||||||
"minItems": true,
|
|
||||||
"maxItems": true,
|
|
||||||
"minProperties": true,
|
|
||||||
"maxProperties": true,
|
|
||||||
"patternProperties": true,
|
|
||||||
"propertyNames": true,
|
|
||||||
"dependencies": true,
|
|
||||||
"dependentSchemas": true,
|
|
||||||
"dependentRequired": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
var warnedSchemaKeys sync.Map
|
|
||||||
|
|
||||||
func schemaCleaningWarningsEnabled() bool {
|
|
||||||
// 可通过环境变量强制开关,方便排查:SUB2API_SCHEMA_CLEAN_WARN=true/false
|
|
||||||
if v := strings.TrimSpace(os.Getenv("SUB2API_SCHEMA_CLEAN_WARN")); v != "" {
|
|
||||||
switch strings.ToLower(v) {
|
|
||||||
case "1", "true", "yes", "on":
|
|
||||||
return true
|
|
||||||
case "0", "false", "no", "off":
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 默认:非 release 模式下输出(debug/test)
|
|
||||||
return gin.Mode() != gin.ReleaseMode
|
|
||||||
}
|
|
||||||
|
|
||||||
func warnSchemaKeyRemovedOnce(key, path string) {
|
|
||||||
if !schemaCleaningWarningsEnabled() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !schemaValidationKeys[key] {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, loaded := warnedSchemaKeys.LoadOrStore(key, struct{}{}); loaded {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q", key, path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// excludedSchemaKeys 不支持的 schema 字段
|
|
||||||
// 基于 Claude API (Vertex AI) 的实际支持情况
|
|
||||||
// 支持: type, description, enum, properties, required, additionalProperties, items
|
|
||||||
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
|
|
||||||
var excludedSchemaKeys = map[string]bool{
|
|
||||||
// 元 schema 字段
|
|
||||||
"$schema": true,
|
|
||||||
"$id": true,
|
|
||||||
"$ref": true,
|
|
||||||
|
|
||||||
// 字符串验证(Gemini 不支持)
|
|
||||||
"minLength": true,
|
|
||||||
"maxLength": true,
|
|
||||||
"pattern": true,
|
|
||||||
|
|
||||||
// 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
|
|
||||||
"minimum": true,
|
|
||||||
"maximum": true,
|
|
||||||
"exclusiveMinimum": true,
|
|
||||||
"exclusiveMaximum": true,
|
|
||||||
"multipleOf": true,
|
|
||||||
|
|
||||||
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
|
|
||||||
"uniqueItems": true,
|
|
||||||
"minItems": true,
|
|
||||||
"maxItems": true,
|
|
||||||
|
|
||||||
// 组合 schema(Gemini 不支持)
|
|
||||||
"oneOf": true,
|
|
||||||
"anyOf": true,
|
|
||||||
"allOf": true,
|
|
||||||
"not": true,
|
|
||||||
"if": true,
|
|
||||||
"then": true,
|
|
||||||
"else": true,
|
|
||||||
"$defs": true,
|
|
||||||
"definitions": true,
|
|
||||||
|
|
||||||
// 对象验证(仅保留 properties/required/additionalProperties)
|
|
||||||
"minProperties": true,
|
|
||||||
"maxProperties": true,
|
|
||||||
"patternProperties": true,
|
|
||||||
"propertyNames": true,
|
|
||||||
"dependencies": true,
|
|
||||||
"dependentSchemas": true,
|
|
||||||
"dependentRequired": true,
|
|
||||||
|
|
||||||
// 其他不支持的字段
|
|
||||||
"default": true,
|
|
||||||
"const": true,
|
|
||||||
"examples": true,
|
|
||||||
"deprecated": true,
|
|
||||||
"readOnly": true,
|
|
||||||
"writeOnly": true,
|
|
||||||
"contentMediaType": true,
|
|
||||||
"contentEncoding": true,
|
|
||||||
|
|
||||||
// Claude 特有字段
|
|
||||||
"strict": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanSchemaValue 递归清理 schema 值
|
|
||||||
func cleanSchemaValue(value any, path string) any {
|
|
||||||
switch v := value.(type) {
|
|
||||||
case map[string]any:
|
|
||||||
result := make(map[string]any)
|
|
||||||
for k, val := range v {
|
|
||||||
// 跳过不支持的字段
|
|
||||||
if excludedSchemaKeys[k] {
|
|
||||||
warnSchemaKeyRemovedOnce(k, path)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 特殊处理 type 字段
|
|
||||||
if k == "type" {
|
|
||||||
result[k] = cleanTypeValue(val)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
|
|
||||||
if k == "format" {
|
|
||||||
if formatStr, ok := val.(string); ok {
|
|
||||||
// Gemini 只支持 date-time, date, time
|
|
||||||
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
|
|
||||||
result[k] = val
|
|
||||||
}
|
|
||||||
// 其他 format 值直接跳过
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
|
|
||||||
if k == "additionalProperties" {
|
|
||||||
if boolVal, ok := val.(bool); ok {
|
|
||||||
result[k] = boolVal
|
|
||||||
} else {
|
|
||||||
// 如果是 schema 对象,转换为 false(更安全的默认值)
|
|
||||||
result[k] = false
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 递归清理所有值
|
|
||||||
result[k] = cleanSchemaValue(val, path+"."+k)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
|
|
||||||
case []any:
|
|
||||||
// 递归处理数组中的每个元素
|
|
||||||
cleaned := make([]any, 0, len(v))
|
|
||||||
for i, item := range v {
|
|
||||||
cleaned = append(cleaned, cleanSchemaValue(item, fmt.Sprintf("%s[%d]", path, i)))
|
|
||||||
}
|
|
||||||
return cleaned
|
|
||||||
|
|
||||||
default:
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanTypeValue 处理 type 字段,转换为大写
|
|
||||||
func cleanTypeValue(value any) any {
|
|
||||||
switch v := value.(type) {
|
|
||||||
case string:
|
|
||||||
return strings.ToUpper(v)
|
|
||||||
case []any:
|
|
||||||
// 联合类型 ["string", "null"] -> 取第一个非 null 类型
|
|
||||||
for _, t := range v {
|
|
||||||
if ts, ok := t.(string); ok && ts != "null" {
|
|
||||||
return strings.ToUpper(ts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 如果只有 null,返回 STRING
|
|
||||||
return "STRING"
|
|
||||||
default:
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
|
|||||||
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
|
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
|
||||||
]`
|
]`
|
||||||
|
|
||||||
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
|
t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
|
||||||
toolIDToName := make(map[string]string)
|
toolIDToName := make(map[string]string)
|
||||||
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
|
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
|
|||||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||||
}
|
}
|
||||||
|
if parts[0].ThoughtSignature != "sig_tool_abc" {
|
||||||
|
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
|
||||||
|
contentNoSig := `[
|
||||||
|
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
|
||||||
|
]`
|
||||||
|
toolIDToName := make(map[string]string)
|
||||||
|
parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildParts() error = %v", err)
|
||||||
|
}
|
||||||
|
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||||
|
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||||
|
}
|
||||||
if parts[0].ThoughtSignature != dummyThoughtSignature {
|
if parts[0].ThoughtSignature != dummyThoughtSignature {
|
||||||
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
|
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package antigravity
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *
|
|||||||
v1Resp.Response = directResp
|
v1Resp.Response = directResp
|
||||||
v1Resp.ResponseID = directResp.ResponseID
|
v1Resp.ResponseID = directResp.ResponseID
|
||||||
v1Resp.ModelVersion = directResp.ModelVersion
|
v1Resp.ModelVersion = directResp.ModelVersion
|
||||||
|
} else if len(v1Resp.Response.Candidates) == 0 {
|
||||||
|
// 第一次解析成功但 candidates 为空,说明是直接的 GeminiResponse 格式
|
||||||
|
var directResp GeminiResponse
|
||||||
|
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
|
||||||
|
return nil, nil, fmt.Errorf("parse gemini response as direct: %w", err2)
|
||||||
|
}
|
||||||
|
v1Resp.Response = directResp
|
||||||
|
v1Resp.ResponseID = directResp.ResponseID
|
||||||
|
v1Resp.ModelVersion = directResp.ModelVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用处理器转换
|
// 使用处理器转换
|
||||||
@@ -173,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
|
|||||||
p.trailingSignature = ""
|
p.trailingSignature = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
p.textBuilder += part.Text
|
// 非空 text 带签名 - 特殊处理:先输出 text,再输出空 thinking 块
|
||||||
|
|
||||||
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
|
|
||||||
if signature != "" {
|
if signature != "" {
|
||||||
p.flushText()
|
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||||
|
Type: "text",
|
||||||
|
Text: part.Text,
|
||||||
|
})
|
||||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||||
Type: "thinking",
|
Type: "thinking",
|
||||||
Thinking: "",
|
Thinking: "",
|
||||||
Signature: signature,
|
Signature: signature,
|
||||||
})
|
})
|
||||||
|
} else {
|
||||||
|
// 普通 text (无签名) - 累积到 builder
|
||||||
|
p.textBuilder += part.Text
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -242,6 +256,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
|
|||||||
var finishReason string
|
var finishReason string
|
||||||
if len(geminiResp.Candidates) > 0 {
|
if len(geminiResp.Candidates) > 0 {
|
||||||
finishReason = geminiResp.Candidates[0].FinishReason
|
finishReason = geminiResp.Candidates[0].FinishReason
|
||||||
|
if finishReason == "MALFORMED_FUNCTION_CALL" {
|
||||||
|
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in response for model %s", originalModel)
|
||||||
|
if geminiResp.Candidates[0].Content != nil {
|
||||||
|
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
|
||||||
|
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stopReason := "end_turn"
|
stopReason := "end_turn"
|
||||||
|
|||||||
519
backend/internal/pkg/antigravity/schema_cleaner.go
Normal file
519
backend/internal/pkg/antigravity/schema_cleaner.go
Normal file
@@ -0,0 +1,519 @@
|
|||||||
|
package antigravity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
|
||||||
|
// 参考 Antigravity-Manager/src-tauri/src/proxy/common/json_schema.rs 实现
|
||||||
|
// 确保 schema 符合 JSON Schema draft 2020-12 且适配 Gemini v1internal
|
||||||
|
func CleanJSONSchema(schema map[string]any) map[string]any {
|
||||||
|
if schema == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// 0. 预处理:展开 $ref (Schema Flattening)
|
||||||
|
// (Go map 是引用的,直接修改 schema)
|
||||||
|
flattenRefs(schema, extractDefs(schema))
|
||||||
|
|
||||||
|
// 递归清理
|
||||||
|
cleaned := cleanJSONSchemaRecursive(schema)
|
||||||
|
result, ok := cleaned.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractDefs 提取并移除定义的 helper
|
||||||
|
func extractDefs(schema map[string]any) map[string]any {
|
||||||
|
defs := make(map[string]any)
|
||||||
|
if d, ok := schema["$defs"].(map[string]any); ok {
|
||||||
|
for k, v := range d {
|
||||||
|
defs[k] = v
|
||||||
|
}
|
||||||
|
delete(schema, "$defs")
|
||||||
|
}
|
||||||
|
if d, ok := schema["definitions"].(map[string]any); ok {
|
||||||
|
for k, v := range d {
|
||||||
|
defs[k] = v
|
||||||
|
}
|
||||||
|
delete(schema, "definitions")
|
||||||
|
}
|
||||||
|
return defs
|
||||||
|
}
|
||||||
|
|
||||||
|
// flattenRefs 递归展开 $ref
|
||||||
|
func flattenRefs(schema map[string]any, defs map[string]any) {
|
||||||
|
if len(defs) == 0 {
|
||||||
|
return // 无需展开
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查并替换 $ref
|
||||||
|
if ref, ok := schema["$ref"].(string); ok {
|
||||||
|
delete(schema, "$ref")
|
||||||
|
// 解析引用名 (例如 #/$defs/MyType -> MyType)
|
||||||
|
parts := strings.Split(ref, "/")
|
||||||
|
refName := parts[len(parts)-1]
|
||||||
|
|
||||||
|
if defSchema, exists := defs[refName]; exists {
|
||||||
|
if defMap, ok := defSchema.(map[string]any); ok {
|
||||||
|
// 合并定义内容 (不覆盖现有 key)
|
||||||
|
for k, v := range defMap {
|
||||||
|
if _, has := schema[k]; !has {
|
||||||
|
schema[k] = deepCopy(v) // 需深拷贝避免共享引用
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 递归处理刚刚合并进来的内容
|
||||||
|
flattenRefs(schema, defs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 遍历子节点
|
||||||
|
for _, v := range schema {
|
||||||
|
if subMap, ok := v.(map[string]any); ok {
|
||||||
|
flattenRefs(subMap, defs)
|
||||||
|
} else if subArr, ok := v.([]any); ok {
|
||||||
|
for _, item := range subArr {
|
||||||
|
if itemMap, ok := item.(map[string]any); ok {
|
||||||
|
flattenRefs(itemMap, defs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// deepCopy 深拷贝 (简单实现,仅针对 JSON 类型)
|
||||||
|
func deepCopy(src any) any {
|
||||||
|
if src == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch v := src.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
dst := make(map[string]any)
|
||||||
|
for k, val := range v {
|
||||||
|
dst[k] = deepCopy(val)
|
||||||
|
}
|
||||||
|
return dst
|
||||||
|
case []any:
|
||||||
|
dst := make([]any, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
dst[i] = deepCopy(val)
|
||||||
|
}
|
||||||
|
return dst
|
||||||
|
default:
|
||||||
|
return src
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanJSONSchemaRecursive 递归核心清理逻辑
|
||||||
|
// 返回处理后的值 (通常是 input map,但可能修改内部结构)
|
||||||
|
func cleanJSONSchemaRecursive(value any) any {
|
||||||
|
schemaMap, ok := value.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0. [NEW] 合并 allOf
|
||||||
|
mergeAllOf(schemaMap)
|
||||||
|
|
||||||
|
// 1. [CRITICAL] 深度递归处理子项
|
||||||
|
if props, ok := schemaMap["properties"].(map[string]any); ok {
|
||||||
|
for _, v := range props {
|
||||||
|
cleanJSONSchemaRecursive(v)
|
||||||
|
}
|
||||||
|
// Go 中不需要像 Rust 那样显式处理 nullable_keys remove required,
|
||||||
|
// 因为我们在子项处理中会正确设置 type 和 description
|
||||||
|
} else if items, ok := schemaMap["items"]; ok {
|
||||||
|
// [FIX] Gemini 期望 "items" 是单个 Schema 对象(列表验证),而不是数组(元组验证)。
|
||||||
|
if itemsArr, ok := items.([]any); ok {
|
||||||
|
// 策略:将元组 [A, B] 视为 A、B 中的最佳匹配项。
|
||||||
|
best := extractBestSchemaFromUnion(itemsArr)
|
||||||
|
if best == nil {
|
||||||
|
// 回退到通用字符串
|
||||||
|
best = map[string]any{"type": "string"}
|
||||||
|
}
|
||||||
|
// 用处理后的对象替换原有数组
|
||||||
|
cleanedBest := cleanJSONSchemaRecursive(best)
|
||||||
|
schemaMap["items"] = cleanedBest
|
||||||
|
} else {
|
||||||
|
cleanJSONSchemaRecursive(items)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 遍历所有值递归
|
||||||
|
for _, v := range schemaMap {
|
||||||
|
if _, isMap := v.(map[string]any); isMap {
|
||||||
|
cleanJSONSchemaRecursive(v)
|
||||||
|
} else if arr, isArr := v.([]any); isArr {
|
||||||
|
for _, item := range arr {
|
||||||
|
cleanJSONSchemaRecursive(item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. [FIX] 处理 anyOf/oneOf 联合类型: 合并属性而非直接删除
|
||||||
|
var unionArray []any
|
||||||
|
typeStr, _ := schemaMap["type"].(string)
|
||||||
|
if typeStr == "" || typeStr == "object" {
|
||||||
|
if anyOf, ok := schemaMap["anyOf"].([]any); ok {
|
||||||
|
unionArray = anyOf
|
||||||
|
} else if oneOf, ok := schemaMap["oneOf"].([]any); ok {
|
||||||
|
unionArray = oneOf
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(unionArray) > 0 {
|
||||||
|
if bestBranch := extractBestSchemaFromUnion(unionArray); bestBranch != nil {
|
||||||
|
if bestMap, ok := bestBranch.(map[string]any); ok {
|
||||||
|
// 合并分支内容
|
||||||
|
for k, v := range bestMap {
|
||||||
|
if k == "properties" {
|
||||||
|
targetProps, _ := schemaMap["properties"].(map[string]any)
|
||||||
|
if targetProps == nil {
|
||||||
|
targetProps = make(map[string]any)
|
||||||
|
schemaMap["properties"] = targetProps
|
||||||
|
}
|
||||||
|
if sourceProps, ok := v.(map[string]any); ok {
|
||||||
|
for pk, pv := range sourceProps {
|
||||||
|
if _, exists := targetProps[pk]; !exists {
|
||||||
|
targetProps[pk] = deepCopy(pv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if k == "required" {
|
||||||
|
targetReq, _ := schemaMap["required"].([]any)
|
||||||
|
if sourceReq, ok := v.([]any); ok {
|
||||||
|
for _, rv := range sourceReq {
|
||||||
|
// 简单的去重添加
|
||||||
|
exists := false
|
||||||
|
for _, tr := range targetReq {
|
||||||
|
if tr == rv {
|
||||||
|
exists = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
targetReq = append(targetReq, rv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
schemaMap["required"] = targetReq
|
||||||
|
}
|
||||||
|
} else if _, exists := schemaMap[k]; !exists {
|
||||||
|
schemaMap[k] = deepCopy(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. [SAFETY] 检查当前对象是否为 JSON Schema 节点
|
||||||
|
looksLikeSchema := hasKey(schemaMap, "type") ||
|
||||||
|
hasKey(schemaMap, "properties") ||
|
||||||
|
hasKey(schemaMap, "items") ||
|
||||||
|
hasKey(schemaMap, "enum") ||
|
||||||
|
hasKey(schemaMap, "anyOf") ||
|
||||||
|
hasKey(schemaMap, "oneOf") ||
|
||||||
|
hasKey(schemaMap, "allOf")
|
||||||
|
|
||||||
|
if looksLikeSchema {
|
||||||
|
// 4. [ROBUST] 约束迁移
|
||||||
|
migrateConstraints(schemaMap)
|
||||||
|
|
||||||
|
// 5. [CRITICAL] 白名单过滤
|
||||||
|
allowedFields := map[string]bool{
|
||||||
|
"type": true,
|
||||||
|
"description": true,
|
||||||
|
"properties": true,
|
||||||
|
"required": true,
|
||||||
|
"items": true,
|
||||||
|
"enum": true,
|
||||||
|
"title": true,
|
||||||
|
}
|
||||||
|
for k := range schemaMap {
|
||||||
|
if !allowedFields[k] {
|
||||||
|
delete(schemaMap, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. [SAFETY] 处理空 Object
|
||||||
|
if t, _ := schemaMap["type"].(string); t == "object" {
|
||||||
|
hasProps := false
|
||||||
|
if props, ok := schemaMap["properties"].(map[string]any); ok && len(props) > 0 {
|
||||||
|
hasProps = true
|
||||||
|
}
|
||||||
|
if !hasProps {
|
||||||
|
schemaMap["properties"] = map[string]any{
|
||||||
|
"reason": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Reason for calling this tool",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
schemaMap["required"] = []any{"reason"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 7. [SAFETY] Required 字段对齐
|
||||||
|
if props, ok := schemaMap["properties"].(map[string]any); ok {
|
||||||
|
if req, ok := schemaMap["required"].([]any); ok {
|
||||||
|
var validReq []any
|
||||||
|
for _, r := range req {
|
||||||
|
if rStr, ok := r.(string); ok {
|
||||||
|
if _, exists := props[rStr]; exists {
|
||||||
|
validReq = append(validReq, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(validReq) > 0 {
|
||||||
|
schemaMap["required"] = validReq
|
||||||
|
} else {
|
||||||
|
delete(schemaMap, "required")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 8. 处理 type 字段 (Lowercase + Nullable 提取)
|
||||||
|
isEffectivelyNullable := false
|
||||||
|
if typeVal, exists := schemaMap["type"]; exists {
|
||||||
|
var selectedType string
|
||||||
|
switch v := typeVal.(type) {
|
||||||
|
case string:
|
||||||
|
lower := strings.ToLower(v)
|
||||||
|
if lower == "null" {
|
||||||
|
isEffectivelyNullable = true
|
||||||
|
selectedType = "string" // fallback
|
||||||
|
} else {
|
||||||
|
selectedType = lower
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
// ["string", "null"]
|
||||||
|
for _, t := range v {
|
||||||
|
if ts, ok := t.(string); ok {
|
||||||
|
lower := strings.ToLower(ts)
|
||||||
|
if lower == "null" {
|
||||||
|
isEffectivelyNullable = true
|
||||||
|
} else if selectedType == "" {
|
||||||
|
selectedType = lower
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if selectedType == "" {
|
||||||
|
selectedType = "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
schemaMap["type"] = selectedType
|
||||||
|
} else {
|
||||||
|
// 默认 object 如果有 properties (虽然上面白名单过滤可能删了 type 如果它不在... 但 type 必在 allowlist)
|
||||||
|
// 如果没有 type,但有 properties,补一个
|
||||||
|
if hasKey(schemaMap, "properties") {
|
||||||
|
schemaMap["type"] = "object"
|
||||||
|
} else {
|
||||||
|
// 默认为 string ? or object? Gemini 通常需要明确 type
|
||||||
|
schemaMap["type"] = "object"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isEffectivelyNullable {
|
||||||
|
desc, _ := schemaMap["description"].(string)
|
||||||
|
if !strings.Contains(desc, "nullable") {
|
||||||
|
if desc != "" {
|
||||||
|
desc += " "
|
||||||
|
}
|
||||||
|
desc += "(nullable)"
|
||||||
|
schemaMap["description"] = desc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 9. Enum 值强制转字符串
|
||||||
|
if enumVals, ok := schemaMap["enum"].([]any); ok {
|
||||||
|
hasNonString := false
|
||||||
|
for i, val := range enumVals {
|
||||||
|
if _, isStr := val.(string); !isStr {
|
||||||
|
hasNonString = true
|
||||||
|
if val == nil {
|
||||||
|
enumVals[i] = "null"
|
||||||
|
} else {
|
||||||
|
enumVals[i] = fmt.Sprintf("%v", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If we mandated string values, we must ensure type is string
|
||||||
|
if hasNonString {
|
||||||
|
schemaMap["type"] = "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return schemaMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasKey(m map[string]any, k string) bool {
|
||||||
|
_, ok := m[k]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func migrateConstraints(m map[string]any) {
|
||||||
|
constraints := []struct {
|
||||||
|
key string
|
||||||
|
label string
|
||||||
|
}{
|
||||||
|
{"minLength", "minLen"},
|
||||||
|
{"maxLength", "maxLen"},
|
||||||
|
{"pattern", "pattern"},
|
||||||
|
{"minimum", "min"},
|
||||||
|
{"maximum", "max"},
|
||||||
|
{"multipleOf", "multipleOf"},
|
||||||
|
{"exclusiveMinimum", "exclMin"},
|
||||||
|
{"exclusiveMaximum", "exclMax"},
|
||||||
|
{"minItems", "minItems"},
|
||||||
|
{"maxItems", "maxItems"},
|
||||||
|
{"propertyNames", "propertyNames"},
|
||||||
|
{"format", "format"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var hints []string
|
||||||
|
for _, c := range constraints {
|
||||||
|
if val, ok := m[c.key]; ok && val != nil {
|
||||||
|
hints = append(hints, fmt.Sprintf("%s: %v", c.label, val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(hints) > 0 {
|
||||||
|
suffix := fmt.Sprintf(" [Constraint: %s]", strings.Join(hints, ", "))
|
||||||
|
desc, _ := m["description"].(string)
|
||||||
|
if !strings.Contains(desc, suffix) {
|
||||||
|
m["description"] = desc + suffix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeAllOf 合并 allOf
|
||||||
|
func mergeAllOf(m map[string]any) {
|
||||||
|
allOf, ok := m["allOf"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(m, "allOf")
|
||||||
|
|
||||||
|
mergedProps := make(map[string]any)
|
||||||
|
mergedReq := make(map[string]bool)
|
||||||
|
otherFields := make(map[string]any)
|
||||||
|
|
||||||
|
for _, sub := range allOf {
|
||||||
|
if subMap, ok := sub.(map[string]any); ok {
|
||||||
|
// Props
|
||||||
|
if props, ok := subMap["properties"].(map[string]any); ok {
|
||||||
|
for k, v := range props {
|
||||||
|
mergedProps[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Required
|
||||||
|
if reqs, ok := subMap["required"].([]any); ok {
|
||||||
|
for _, r := range reqs {
|
||||||
|
if s, ok := r.(string); ok {
|
||||||
|
mergedReq[s] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Others
|
||||||
|
for k, v := range subMap {
|
||||||
|
if k != "properties" && k != "required" && k != "allOf" {
|
||||||
|
if _, exists := otherFields[k]; !exists {
|
||||||
|
otherFields[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply
|
||||||
|
for k, v := range otherFields {
|
||||||
|
if _, exists := m[k]; !exists {
|
||||||
|
m[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(mergedProps) > 0 {
|
||||||
|
existProps, _ := m["properties"].(map[string]any)
|
||||||
|
if existProps == nil {
|
||||||
|
existProps = make(map[string]any)
|
||||||
|
m["properties"] = existProps
|
||||||
|
}
|
||||||
|
for k, v := range mergedProps {
|
||||||
|
if _, exists := existProps[k]; !exists {
|
||||||
|
existProps[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(mergedReq) > 0 {
|
||||||
|
existReq, _ := m["required"].([]any)
|
||||||
|
var validReqs []any
|
||||||
|
for _, r := range existReq {
|
||||||
|
if s, ok := r.(string); ok {
|
||||||
|
validReqs = append(validReqs, s)
|
||||||
|
delete(mergedReq, s) // already exists
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// append new
|
||||||
|
for r := range mergedReq {
|
||||||
|
validReqs = append(validReqs, r)
|
||||||
|
}
|
||||||
|
m["required"] = validReqs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractBestSchemaFromUnion 从 anyOf/oneOf 中选取最佳分支
|
||||||
|
func extractBestSchemaFromUnion(unionArray []any) any {
|
||||||
|
var bestOption any
|
||||||
|
bestScore := -1
|
||||||
|
|
||||||
|
for _, item := range unionArray {
|
||||||
|
score := scoreSchemaOption(item)
|
||||||
|
if score > bestScore {
|
||||||
|
bestScore = score
|
||||||
|
bestOption = item
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bestOption
|
||||||
|
}
|
||||||
|
|
||||||
|
func scoreSchemaOption(val any) int {
|
||||||
|
m, ok := val.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
typeStr, _ := m["type"].(string)
|
||||||
|
|
||||||
|
if hasKey(m, "properties") || typeStr == "object" {
|
||||||
|
return 3
|
||||||
|
}
|
||||||
|
if hasKey(m, "items") || typeStr == "array" {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
if typeStr != "" && typeStr != "null" {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeepCleanUndefined 深度清理值为 "[undefined]" 的字段
|
||||||
|
func DeepCleanUndefined(value any) {
|
||||||
|
if value == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch v := value.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
for k, val := range v {
|
||||||
|
if s, ok := val.(string); ok && s == "[undefined]" {
|
||||||
|
delete(v, k)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
DeepCleanUndefined(val)
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, val := range v {
|
||||||
|
DeepCleanUndefined(val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -102,6 +103,14 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
|||||||
// 检查是否结束
|
// 检查是否结束
|
||||||
if len(geminiResp.Candidates) > 0 {
|
if len(geminiResp.Candidates) > 0 {
|
||||||
finishReason := geminiResp.Candidates[0].FinishReason
|
finishReason := geminiResp.Candidates[0].FinishReason
|
||||||
|
if finishReason == "MALFORMED_FUNCTION_CALL" {
|
||||||
|
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in stream for model %s", p.originalModel)
|
||||||
|
if geminiResp.Candidates[0].Content != nil {
|
||||||
|
if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
|
||||||
|
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if finishReason != "" {
|
if finishReason != "" {
|
||||||
_, _ = result.Write(p.emitFinish(finishReason))
|
_, _ = result.Write(p.emitFinish(finishReason))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,20 +13,26 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Claude OAuth Constants (from CRS project)
|
// Claude OAuth Constants
|
||||||
const (
|
const (
|
||||||
// OAuth Client ID for Claude
|
// OAuth Client ID for Claude
|
||||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||||
|
|
||||||
// OAuth endpoints
|
// OAuth endpoints
|
||||||
AuthorizeURL = "https://claude.ai/oauth/authorize"
|
AuthorizeURL = "https://claude.ai/oauth/authorize"
|
||||||
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
TokenURL = "https://platform.claude.com/v1/oauth/token"
|
||||||
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
|
RedirectURI = "https://platform.claude.com/oauth/code/callback"
|
||||||
|
|
||||||
// Scopes
|
// Scopes - Browser URL (includes org:create_api_key for user authorization)
|
||||||
ScopeProfile = "user:profile"
|
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
|
||||||
|
// Scopes - Internal API call (org:create_api_key not supported in API)
|
||||||
|
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
|
||||||
|
// Scopes - Setup token (inference only)
|
||||||
ScopeInference = "user:inference"
|
ScopeInference = "user:inference"
|
||||||
|
|
||||||
|
// Code Verifier character set (RFC 7636 compliant)
|
||||||
|
codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
|
||||||
|
|
||||||
// Session TTL
|
// Session TTL
|
||||||
SessionTTL = 30 * time.Minute
|
SessionTTL = 30 * time.Minute
|
||||||
)
|
)
|
||||||
@@ -53,7 +59,6 @@ func NewSessionStore() *SessionStore {
|
|||||||
sessions: make(map[string]*OAuthSession),
|
sessions: make(map[string]*OAuthSession),
|
||||||
stopCh: make(chan struct{}),
|
stopCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
// Start cleanup goroutine
|
|
||||||
go store.cleanup()
|
go store.cleanup()
|
||||||
return store
|
return store
|
||||||
}
|
}
|
||||||
@@ -78,7 +83,6 @@ func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
// Check if expired
|
|
||||||
if time.Since(session.CreatedAt) > SessionTTL {
|
if time.Since(session.CreatedAt) > SessionTTL {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -122,13 +126,13 @@ func GenerateRandomBytes(n int) ([]byte, error) {
|
|||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateState generates a random state string for OAuth
|
// GenerateState generates a random state string for OAuth (base64url encoded)
|
||||||
func GenerateState() (string, error) {
|
func GenerateState() (string, error) {
|
||||||
bytes, err := GenerateRandomBytes(32)
|
bytes, err := GenerateRandomBytes(32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return hex.EncodeToString(bytes), nil
|
return base64URLEncode(bytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateSessionID generates a unique session ID
|
// GenerateSessionID generates a unique session ID
|
||||||
@@ -140,13 +144,30 @@ func GenerateSessionID() (string, error) {
|
|||||||
return hex.EncodeToString(bytes), nil
|
return hex.EncodeToString(bytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
|
// GenerateCodeVerifier generates a PKCE code verifier using character set method
|
||||||
func GenerateCodeVerifier() (string, error) {
|
func GenerateCodeVerifier() (string, error) {
|
||||||
bytes, err := GenerateRandomBytes(32)
|
const targetLen = 32
|
||||||
if err != nil {
|
charsetLen := len(codeVerifierCharset)
|
||||||
return "", err
|
limit := 256 - (256 % charsetLen)
|
||||||
|
|
||||||
|
result := make([]byte, 0, targetLen)
|
||||||
|
randBuf := make([]byte, targetLen*2)
|
||||||
|
|
||||||
|
for len(result) < targetLen {
|
||||||
|
if _, err := rand.Read(randBuf); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
for _, b := range randBuf {
|
||||||
|
if int(b) < limit {
|
||||||
|
result = append(result, codeVerifierCharset[int(b)%charsetLen])
|
||||||
|
if len(result) >= targetLen {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return base64URLEncode(bytes), nil
|
|
||||||
|
return base64URLEncode(result), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||||
@@ -158,42 +179,31 @@ func GenerateCodeChallenge(verifier string) string {
|
|||||||
// base64URLEncode encodes bytes to base64url without padding
|
// base64URLEncode encodes bytes to base64url without padding
|
||||||
func base64URLEncode(data []byte) string {
|
func base64URLEncode(data []byte) string {
|
||||||
encoded := base64.URLEncoding.EncodeToString(data)
|
encoded := base64.URLEncoding.EncodeToString(data)
|
||||||
// Remove padding
|
|
||||||
return strings.TrimRight(encoded, "=")
|
return strings.TrimRight(encoded, "=")
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildAuthorizationURL builds the OAuth authorization URL
|
// BuildAuthorizationURL builds the OAuth authorization URL with correct parameter order
|
||||||
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
|
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
|
||||||
params := url.Values{}
|
encodedRedirectURI := url.QueryEscape(RedirectURI)
|
||||||
params.Set("response_type", "code")
|
encodedScope := strings.ReplaceAll(url.QueryEscape(scope), "%20", "+")
|
||||||
params.Set("client_id", ClientID)
|
|
||||||
params.Set("redirect_uri", RedirectURI)
|
|
||||||
params.Set("scope", scope)
|
|
||||||
params.Set("state", state)
|
|
||||||
params.Set("code_challenge", codeChallenge)
|
|
||||||
params.Set("code_challenge_method", "S256")
|
|
||||||
|
|
||||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
return fmt.Sprintf("%s?code=true&client_id=%s&response_type=code&redirect_uri=%s&scope=%s&code_challenge=%s&code_challenge_method=S256&state=%s",
|
||||||
}
|
AuthorizeURL,
|
||||||
|
ClientID,
|
||||||
// TokenRequest represents the token exchange request body
|
encodedRedirectURI,
|
||||||
type TokenRequest struct {
|
encodedScope,
|
||||||
GrantType string `json:"grant_type"`
|
codeChallenge,
|
||||||
ClientID string `json:"client_id"`
|
state,
|
||||||
Code string `json:"code"`
|
)
|
||||||
RedirectURI string `json:"redirect_uri"`
|
|
||||||
CodeVerifier string `json:"code_verifier"`
|
|
||||||
State string `json:"state"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenResponse represents the token response from OAuth provider
|
// TokenResponse represents the token response from OAuth provider
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
TokenType string `json:"token_type"`
|
TokenType string `json:"token_type"`
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
Scope string `json:"scope,omitempty"`
|
Scope string `json:"scope,omitempty"`
|
||||||
// Organization and Account info from OAuth response
|
|
||||||
Organization *OrgInfo `json:"organization,omitempty"`
|
Organization *OrgInfo `json:"organization,omitempty"`
|
||||||
Account *AccountInfo `json:"account,omitempty"`
|
Account *AccountInfo `json:"account,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -205,33 +215,6 @@ type OrgInfo struct {
|
|||||||
|
|
||||||
// AccountInfo represents account info from OAuth response
|
// AccountInfo represents account info from OAuth response
|
||||||
type AccountInfo struct {
|
type AccountInfo struct {
|
||||||
UUID string `json:"uuid"`
|
UUID string `json:"uuid"`
|
||||||
}
|
EmailAddress string `json:"email_address"`
|
||||||
|
|
||||||
// RefreshTokenRequest represents the refresh token request
|
|
||||||
type RefreshTokenRequest struct {
|
|
||||||
GrantType string `json:"grant_type"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildTokenRequest creates a token exchange request
|
|
||||||
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
|
|
||||||
return &TokenRequest{
|
|
||||||
GrantType: "authorization_code",
|
|
||||||
ClientID: ClientID,
|
|
||||||
Code: code,
|
|
||||||
RedirectURI: RedirectURI,
|
|
||||||
CodeVerifier: codeVerifier,
|
|
||||||
State: state,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildRefreshTokenRequest creates a refresh token request
|
|
||||||
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
|
|
||||||
return &RefreshTokenRequest{
|
|
||||||
GrantType: "refresh_token",
|
|
||||||
RefreshToken: refreshToken,
|
|
||||||
ClientID: ClientID,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
package response
|
package response
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
@@ -74,6 +75,12 @@ func ErrorFrom(c *gin.Context, err error) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
statusCode, status := infraerrors.ToHTTP(err)
|
statusCode, status := infraerrors.ToHTTP(err)
|
||||||
|
|
||||||
|
// Log internal errors with full details for debugging
|
||||||
|
if statusCode >= 500 && c.Request != nil {
|
||||||
|
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
|
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
278
backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
Normal file
278
backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||||
|
//
|
||||||
|
// Integration tests for verifying TLS fingerprint correctness.
|
||||||
|
// These tests make actual network requests to external services and should be run manually.
|
||||||
|
//
|
||||||
|
// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
|
||||||
|
package tlsfingerprint
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// skipIfExternalServiceUnavailable checks if the external service is available.
|
||||||
|
// If not, it skips the test instead of failing.
|
||||||
|
func skipIfExternalServiceUnavailable(t *testing.T, err error) {
|
||||||
|
t.Helper()
|
||||||
|
if err != nil {
|
||||||
|
// Check for common network/TLS errors that indicate external service issues
|
||||||
|
errStr := err.Error()
|
||||||
|
if strings.Contains(errStr, "certificate has expired") ||
|
||||||
|
strings.Contains(errStr, "certificate is not yet valid") ||
|
||||||
|
strings.Contains(errStr, "connection refused") ||
|
||||||
|
strings.Contains(errStr, "no such host") ||
|
||||||
|
strings.Contains(errStr, "network is unreachable") ||
|
||||||
|
strings.Contains(errStr, "timeout") {
|
||||||
|
t.Skipf("skipping test: external service unavailable: %v", err)
|
||||||
|
}
|
||||||
|
t.Fatalf("failed to get fingerprint: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
|
||||||
|
// This test uses tls.peet.ws to verify the fingerprint.
|
||||||
|
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
|
||||||
|
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
|
||||||
|
func TestJA3Fingerprint(t *testing.T) {
|
||||||
|
// Skip if network is unavailable or if running in short mode
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
profile := &Profile{
|
||||||
|
Name: "Claude CLI Test",
|
||||||
|
EnableGREASE: false,
|
||||||
|
}
|
||||||
|
dialer := NewDialer(profile, nil)
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialTLSContext: dialer.DialTLSContext,
|
||||||
|
},
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use tls.peet.ws fingerprint detection API
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
skipIfExternalServiceUnavailable(t, err)
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var fpResp FingerprintResponse
|
||||||
|
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||||
|
t.Logf("Response body: %s", string(body))
|
||||||
|
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log all fingerprint information
|
||||||
|
t.Logf("JA3: %s", fpResp.TLS.JA3)
|
||||||
|
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
|
||||||
|
t.Logf("JA4: %s", fpResp.TLS.JA4)
|
||||||
|
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
|
||||||
|
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
|
||||||
|
|
||||||
|
// Verify JA3 hash matches expected value
|
||||||
|
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
|
||||||
|
if fpResp.TLS.JA3Hash == expectedJA3Hash {
|
||||||
|
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
|
||||||
|
} else {
|
||||||
|
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify JA4 fingerprint
|
||||||
|
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
|
||||||
|
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
|
||||||
|
// The suffix _a33745022dd6_1f22a2ca17c4 should match
|
||||||
|
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
|
||||||
|
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
|
||||||
|
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
|
||||||
|
} else {
|
||||||
|
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
|
||||||
|
// d = domain (SNI present), i = IP (no SNI)
|
||||||
|
// Since we connect to tls.peet.ws (domain), we expect 'd'
|
||||||
|
expectedJA4Prefix := "t13d5911h1"
|
||||||
|
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
|
||||||
|
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
|
||||||
|
} else {
|
||||||
|
// Also accept 'i' variant for IP connections
|
||||||
|
altPrefix := "t13i5911h1"
|
||||||
|
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
|
||||||
|
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
|
||||||
|
} else {
|
||||||
|
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
|
||||||
|
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
|
||||||
|
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
|
||||||
|
} else {
|
||||||
|
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify extension list (should be 11 extensions including SNI)
|
||||||
|
// Expected: 0-11-10-35-16-22-23-13-43-45-51
|
||||||
|
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
|
||||||
|
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
|
||||||
|
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
|
||||||
|
} else {
|
||||||
|
t.Logf("Warning: JA3 extension list may differ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProfileExpectation defines expected fingerprint values for a profile.
|
||||||
|
type TestProfileExpectation struct {
|
||||||
|
Profile *Profile
|
||||||
|
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
|
||||||
|
ExpectedJA4 string // Expected full JA4 (empty = don't check)
|
||||||
|
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
|
||||||
|
// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
|
||||||
|
func TestAllProfiles(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define all profiles to test with their expected fingerprints
|
||||||
|
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
|
||||||
|
profiles := []TestProfileExpectation{
|
||||||
|
{
|
||||||
|
// Linux x64 Node.js v22.17.1
|
||||||
|
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
|
||||||
|
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
|
||||||
|
Profile: &Profile{
|
||||||
|
Name: "linux_x64_node_v22171",
|
||||||
|
EnableGREASE: false,
|
||||||
|
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||||
|
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||||
|
PointFormats: []uint8{0, 1, 2},
|
||||||
|
},
|
||||||
|
JA4CipherHash: "a33745022dd6", // stable part
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// MacOS arm64 Node.js v22.18.0
|
||||||
|
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
|
||||||
|
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
|
||||||
|
Profile: &Profile{
|
||||||
|
Name: "macos_arm64_node_v22180",
|
||||||
|
EnableGREASE: false,
|
||||||
|
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||||
|
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||||
|
PointFormats: []uint8{0, 1, 2},
|
||||||
|
},
|
||||||
|
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range profiles {
|
||||||
|
tc := tc // capture range variable
|
||||||
|
t.Run(tc.Profile.Name, func(t *testing.T) {
|
||||||
|
fp := fetchFingerprint(t, tc.Profile)
|
||||||
|
if fp == nil {
|
||||||
|
return // fetchFingerprint already called t.Fatal
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Profile: %s", tc.Profile.Name)
|
||||||
|
t.Logf(" JA3: %s", fp.JA3)
|
||||||
|
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
|
||||||
|
t.Logf(" JA4: %s", fp.JA4)
|
||||||
|
t.Logf(" PeetPrint: %s", fp.PeetPrint)
|
||||||
|
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
|
||||||
|
|
||||||
|
// Verify expectations
|
||||||
|
if tc.ExpectedJA3 != "" {
|
||||||
|
if fp.JA3Hash == tc.ExpectedJA3 {
|
||||||
|
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
|
||||||
|
} else {
|
||||||
|
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.ExpectedJA4 != "" {
|
||||||
|
if fp.JA4 == tc.ExpectedJA4 {
|
||||||
|
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
|
||||||
|
} else {
|
||||||
|
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check JA4 cipher hash (stable middle part)
|
||||||
|
// JA4 format: prefix_cipherHash_extHash
|
||||||
|
if tc.JA4CipherHash != "" {
|
||||||
|
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
|
||||||
|
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
|
||||||
|
} else {
|
||||||
|
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
|
||||||
|
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
dialer := NewDialer(profile, nil)
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialTLSContext: dialer.DialTLSContext,
|
||||||
|
},
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create request: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
skipIfExternalServiceUnavailable(t, err)
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read response: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var fpResp FingerprintResponse
|
||||||
|
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||||
|
t.Logf("Response body: %s", string(body))
|
||||||
|
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &fpResp.TLS
|
||||||
|
}
|
||||||
@@ -1,21 +1,16 @@
|
|||||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||||
//
|
//
|
||||||
// Integration tests for verifying TLS fingerprint correctness.
|
// Unit tests for TLS fingerprint dialer.
|
||||||
// These tests make actual network requests and should be run manually.
|
// Integration tests that require external network are in dialer_integration_test.go
|
||||||
|
// and require the 'integration' build tag.
|
||||||
//
|
//
|
||||||
// Run with: go test -v ./internal/pkg/tlsfingerprint/...
|
// Run unit tests: go test -v ./internal/pkg/tlsfingerprint/...
|
||||||
// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/...
|
// Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
|
||||||
package tlsfingerprint
|
package tlsfingerprint
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// FingerprintResponse represents the response from tls.peet.ws/api/all.
|
// FingerprintResponse represents the response from tls.peet.ws/api/all.
|
||||||
@@ -36,148 +31,6 @@ type TLSInfo struct {
|
|||||||
SessionID string `json:"session_id"`
|
SessionID string `json:"session_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
|
|
||||||
func TestDialerBasicConnection(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("skipping network test in short mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a dialer with default profile
|
|
||||||
profile := &Profile{
|
|
||||||
Name: "Test Profile",
|
|
||||||
EnableGREASE: false,
|
|
||||||
}
|
|
||||||
dialer := NewDialer(profile, nil)
|
|
||||||
|
|
||||||
// Create HTTP client with custom TLS dialer
|
|
||||||
client := &http.Client{
|
|
||||||
Transport: &http.Transport{
|
|
||||||
DialTLSContext: dialer.DialTLSContext,
|
|
||||||
},
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make a request to a known HTTPS endpoint
|
|
||||||
resp, err := client.Get("https://www.google.com")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to connect: %v", err)
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
|
|
||||||
// This test uses tls.peet.ws to verify the fingerprint.
|
|
||||||
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
|
|
||||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
|
|
||||||
func TestJA3Fingerprint(t *testing.T) {
|
|
||||||
// Skip if network is unavailable or if running in short mode
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("skipping integration test in short mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
profile := &Profile{
|
|
||||||
Name: "Claude CLI Test",
|
|
||||||
EnableGREASE: false,
|
|
||||||
}
|
|
||||||
dialer := NewDialer(profile, nil)
|
|
||||||
|
|
||||||
client := &http.Client{
|
|
||||||
Transport: &http.Transport{
|
|
||||||
DialTLSContext: dialer.DialTLSContext,
|
|
||||||
},
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use tls.peet.ws fingerprint detection API
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create request: %v", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get fingerprint: %v", err)
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to read response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var fpResp FingerprintResponse
|
|
||||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
|
||||||
t.Logf("Response body: %s", string(body))
|
|
||||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log all fingerprint information
|
|
||||||
t.Logf("JA3: %s", fpResp.TLS.JA3)
|
|
||||||
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
|
|
||||||
t.Logf("JA4: %s", fpResp.TLS.JA4)
|
|
||||||
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
|
|
||||||
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
|
|
||||||
|
|
||||||
// Verify JA3 hash matches expected value
|
|
||||||
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
|
|
||||||
if fpResp.TLS.JA3Hash == expectedJA3Hash {
|
|
||||||
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
|
|
||||||
} else {
|
|
||||||
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify JA4 fingerprint
|
|
||||||
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
|
|
||||||
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
|
|
||||||
// The suffix _a33745022dd6_1f22a2ca17c4 should match
|
|
||||||
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
|
|
||||||
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
|
|
||||||
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
|
|
||||||
} else {
|
|
||||||
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
|
|
||||||
// d = domain (SNI present), i = IP (no SNI)
|
|
||||||
// Since we connect to tls.peet.ws (domain), we expect 'd'
|
|
||||||
expectedJA4Prefix := "t13d5911h1"
|
|
||||||
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
|
|
||||||
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
|
|
||||||
} else {
|
|
||||||
// Also accept 'i' variant for IP connections
|
|
||||||
altPrefix := "t13i5911h1"
|
|
||||||
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
|
|
||||||
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
|
|
||||||
} else {
|
|
||||||
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
|
|
||||||
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
|
|
||||||
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
|
|
||||||
} else {
|
|
||||||
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify extension list (should be 11 extensions including SNI)
|
|
||||||
// Expected: 0-11-10-35-16-22-23-13-43-45-51
|
|
||||||
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
|
|
||||||
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
|
|
||||||
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
|
|
||||||
} else {
|
|
||||||
t.Logf("Warning: JA3 extension list may differ")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDialerWithProfile tests that different profiles produce different fingerprints.
|
// TestDialerWithProfile tests that different profiles produce different fingerprints.
|
||||||
func TestDialerWithProfile(t *testing.T) {
|
func TestDialerWithProfile(t *testing.T) {
|
||||||
// Create two dialers with different profiles
|
// Create two dialers with different profiles
|
||||||
|
|||||||
@@ -39,9 +39,15 @@ import (
|
|||||||
// 设计说明:
|
// 设计说明:
|
||||||
// - client: Ent 客户端,用于类型安全的 ORM 操作
|
// - client: Ent 客户端,用于类型安全的 ORM 操作
|
||||||
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
|
// - sql: 原生 SQL 执行器,用于复杂查询和批量操作
|
||||||
|
// - schedulerCache: 调度器缓存,用于在账号状态变更时同步快照
|
||||||
type accountRepository struct {
|
type accountRepository struct {
|
||||||
client *dbent.Client // Ent ORM 客户端
|
client *dbent.Client // Ent ORM 客户端
|
||||||
sql sqlExecutor // 原生 SQL 执行接口
|
sql sqlExecutor // 原生 SQL 执行接口
|
||||||
|
// schedulerCache 用于在账号状态变更时主动同步快照到缓存,
|
||||||
|
// 确保粘性会话能及时感知账号不可用状态。
|
||||||
|
// Used to proactively sync account snapshot to cache when status changes,
|
||||||
|
// ensuring sticky sessions can promptly detect unavailable accounts.
|
||||||
|
schedulerCache service.SchedulerCache
|
||||||
}
|
}
|
||||||
|
|
||||||
type tempUnschedSnapshot struct {
|
type tempUnschedSnapshot struct {
|
||||||
@@ -51,14 +57,14 @@ type tempUnschedSnapshot struct {
|
|||||||
|
|
||||||
// NewAccountRepository 创建账户仓储实例。
|
// NewAccountRepository 创建账户仓储实例。
|
||||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
|
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
||||||
return newAccountRepositoryWithSQL(client, sqlDB)
|
return newAccountRepositoryWithSQL(client, sqlDB, schedulerCache)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
|
// newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。
|
||||||
// 这种设计便于单元测试时注入 mock 对象。
|
// 这种设计便于单元测试时注入 mock 对象。
|
||||||
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository {
|
func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor, schedulerCache service.SchedulerCache) *accountRepository {
|
||||||
return &accountRepository{client: client, sql: sqlq}
|
return &accountRepository{client: client, sql: sqlq, schedulerCache: schedulerCache}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
|
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
|
||||||
@@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
|||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||||
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||||
}
|
}
|
||||||
|
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||||
|
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -540,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
|||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||||
}
|
}
|
||||||
|
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// syncSchedulerAccountSnapshot 在账号状态变更时主动同步快照到调度器缓存。
|
||||||
|
// 当账号被设置为错误、禁用、不可调度或临时不可调度时调用,
|
||||||
|
// 确保调度器和粘性会话逻辑能及时感知账号的最新状态,避免继续使用不可用账号。
|
||||||
|
//
|
||||||
|
// syncSchedulerAccountSnapshot proactively syncs account snapshot to scheduler cache
|
||||||
|
// when account status changes. Called when account is set to error, disabled,
|
||||||
|
// unschedulable, or temporarily unschedulable, ensuring scheduler and sticky session
|
||||||
|
// logic can promptly detect the latest account state and avoid using unavailable accounts.
|
||||||
|
func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, accountID int64) {
|
||||||
|
if r == nil || r.schedulerCache == nil || accountID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
account, err := r.GetByID(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||||
|
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||||
_, err := r.client.Account.Update().
|
_, err := r.client.Account.Update().
|
||||||
Where(dbaccount.IDEQ(id)).
|
Where(dbaccount.IDEQ(id)).
|
||||||
@@ -873,6 +905,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
|
|||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||||
}
|
}
|
||||||
|
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -992,6 +1025,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
|||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||||
}
|
}
|
||||||
|
if !schedulable {
|
||||||
|
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1146,6 +1182,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
|||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
|
||||||
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||||
}
|
}
|
||||||
|
shouldSync := false
|
||||||
|
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
|
||||||
|
shouldSync = true
|
||||||
|
}
|
||||||
|
if updates.Schedulable != nil && !*updates.Schedulable {
|
||||||
|
shouldSync = true
|
||||||
|
}
|
||||||
|
if shouldSync {
|
||||||
|
for _, id := range ids {
|
||||||
|
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return rows, nil
|
return rows, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,11 +21,56 @@ type AccountRepoSuite struct {
|
|||||||
repo *accountRepository
|
repo *accountRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type schedulerCacheRecorder struct {
|
||||||
|
setAccounts []*service.Account
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
||||||
|
s.setAccounts = append(s.setAccounts, account)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schedulerCacheRecorder) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schedulerCacheRecorder) GetOutboxWatermark(ctx context.Context) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schedulerCacheRecorder) SetOutboxWatermark(ctx context.Context, id int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AccountRepoSuite) SetupTest() {
|
func (s *AccountRepoSuite) SetupTest() {
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
tx := testEntTx(s.T())
|
tx := testEntTx(s.T())
|
||||||
s.client = tx.Client()
|
s.client = tx.Client()
|
||||||
s.repo = newAccountRepositoryWithSQL(s.client, tx)
|
s.repo = newAccountRepositoryWithSQL(s.client, tx, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountRepoSuite(t *testing.T) {
|
func TestAccountRepoSuite(t *testing.T) {
|
||||||
@@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() {
|
|||||||
s.Require().Equal("updated", got.Name)
|
s.Require().Equal("updated", got.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "sync-update", Status: service.StatusActive, Schedulable: true})
|
||||||
|
cacheRecorder := &schedulerCacheRecorder{}
|
||||||
|
s.repo.schedulerCache = cacheRecorder
|
||||||
|
|
||||||
|
account.Status = service.StatusDisabled
|
||||||
|
err := s.repo.Update(s.ctx, account)
|
||||||
|
s.Require().NoError(err, "Update")
|
||||||
|
|
||||||
|
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||||
|
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||||
|
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AccountRepoSuite) TestDelete() {
|
func (s *AccountRepoSuite) TestDelete() {
|
||||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||||
|
|
||||||
@@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
|||||||
// 每个 case 重新获取隔离资源
|
// 每个 case 重新获取隔离资源
|
||||||
tx := testEntTx(s.T())
|
tx := testEntTx(s.T())
|
||||||
client := tx.Client()
|
client := tx.Client()
|
||||||
repo := newAccountRepositoryWithSQL(client, tx)
|
repo := newAccountRepositoryWithSQL(client, tx, nil)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
tt.setup(client)
|
tt.setup(client)
|
||||||
@@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
|
|||||||
|
|
||||||
func (s *AccountRepoSuite) TestSetSchedulable() {
|
func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
|
||||||
|
cacheRecorder := &schedulerCacheRecorder{}
|
||||||
|
s.repo.schedulerCache = cacheRecorder
|
||||||
|
|
||||||
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
|
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
|
||||||
|
|
||||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().False(got.Schedulable)
|
s.Require().False(got.Schedulable)
|
||||||
|
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||||
|
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestBulkUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||||
|
account1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-1", Status: service.StatusActive, Schedulable: true})
|
||||||
|
account2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-2", Status: service.StatusActive, Schedulable: true})
|
||||||
|
cacheRecorder := &schedulerCacheRecorder{}
|
||||||
|
s.repo.schedulerCache = cacheRecorder
|
||||||
|
|
||||||
|
disabled := service.StatusDisabled
|
||||||
|
rows, err := s.repo.BulkUpdate(s.ctx, []int64{account1.ID, account2.ID}, service.AccountBulkUpdate{
|
||||||
|
Status: &disabled,
|
||||||
|
})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(int64(2), rows)
|
||||||
|
|
||||||
|
s.Require().Len(cacheRecorder.setAccounts, 2)
|
||||||
|
ids := map[int64]struct{}{}
|
||||||
|
for _, acc := range cacheRecorder.setAccounts {
|
||||||
|
ids[acc.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
s.Require().Contains(ids, account1.ID)
|
||||||
|
s.Require().Contains(ids, account2.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
|
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
|
||||||
|
|||||||
95
backend/internal/repository/aes_encryptor.go
Normal file
95
backend/internal/repository/aes_encryptor.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AESEncryptor implements SecretEncryptor using AES-256-GCM
|
||||||
|
type AESEncryptor struct {
|
||||||
|
key []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAESEncryptor creates a new AES encryptor
|
||||||
|
func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
|
||||||
|
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid totp encryption key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(key) != 32 {
|
||||||
|
return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AESEncryptor{key: key}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt encrypts plaintext using AES-256-GCM
|
||||||
|
// Output format: base64(nonce + ciphertext + tag)
|
||||||
|
func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
|
||||||
|
block, err := aes.NewCipher(e.key)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create gcm: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a random nonce
|
||||||
|
nonce := make([]byte, gcm.NonceSize())
|
||||||
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||||
|
return "", fmt.Errorf("generate nonce: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the plaintext
|
||||||
|
// Seal appends the ciphertext and tag to the nonce
|
||||||
|
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||||
|
|
||||||
|
// Encode as base64
|
||||||
|
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt decrypts ciphertext using AES-256-GCM
|
||||||
|
func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
|
||||||
|
// Decode from base64
|
||||||
|
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decode base64: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(e.key)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create gcm: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nonceSize := gcm.NonceSize()
|
||||||
|
if len(data) < nonceSize {
|
||||||
|
return "", fmt.Errorf("ciphertext too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract nonce and ciphertext
|
||||||
|
nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
|
||||||
|
|
||||||
|
// Decrypt
|
||||||
|
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decrypt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(plaintext), nil
|
||||||
|
}
|
||||||
@@ -387,17 +387,20 @@ func userEntityToService(u *dbent.User) *service.User {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &service.User{
|
return &service.User{
|
||||||
ID: u.ID,
|
ID: u.ID,
|
||||||
Email: u.Email,
|
Email: u.Email,
|
||||||
Username: u.Username,
|
Username: u.Username,
|
||||||
Notes: u.Notes,
|
Notes: u.Notes,
|
||||||
PasswordHash: u.PasswordHash,
|
PasswordHash: u.PasswordHash,
|
||||||
Role: u.Role,
|
Role: u.Role,
|
||||||
Balance: u.Balance,
|
Balance: u.Balance,
|
||||||
Concurrency: u.Concurrency,
|
Concurrency: u.Concurrency,
|
||||||
Status: u.Status,
|
Status: u.Status,
|
||||||
CreatedAt: u.CreatedAt,
|
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||||
UpdatedAt: u.UpdatedAt,
|
TotpEnabled: u.TotpEnabled,
|
||||||
|
TotpEnabledAt: u.TotpEnabledAt,
|
||||||
|
CreatedAt: u.CreatedAt,
|
||||||
|
UpdatedAt: u.UpdatedAt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
|||||||
client := s.clientFactory(proxyURL)
|
client := s.clientFactory(proxyURL)
|
||||||
|
|
||||||
var orgs []struct {
|
var orgs []struct {
|
||||||
UUID string `json:"uuid"`
|
UUID string `json:"uuid"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
RavenType *string `json:"raven_type"` // nil for personal, "team" for team organization
|
||||||
}
|
}
|
||||||
|
|
||||||
targetURL := s.baseURL + "/api/organizations"
|
targetURL := s.baseURL + "/api/organizations"
|
||||||
@@ -65,7 +67,23 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
|||||||
return "", fmt.Errorf("no organizations found")
|
return "", fmt.Errorf("no organizations found")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
|
// 如果只有一个组织,直接使用
|
||||||
|
if len(orgs) == 1 {
|
||||||
|
log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||||
|
return orgs[0].UUID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
|
||||||
|
for _, org := range orgs {
|
||||||
|
if org.RavenType != nil && *org.RavenType == "team" {
|
||||||
|
log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
|
||||||
|
org.UUID, org.Name, *org.RavenType)
|
||||||
|
return org.UUID, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有 team 类型的组织,使用第一个
|
||||||
|
log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||||
return orgs[0].UUID, nil
|
return orgs[0].UUID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -182,7 +200,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
|||||||
|
|
||||||
resp, err := client.R().
|
resp, err := client.R().
|
||||||
SetContext(ctx).
|
SetContext(ctx).
|
||||||
|
SetHeader("Accept", "application/json, text/plain, */*").
|
||||||
SetHeader("Content-Type", "application/json").
|
SetHeader("Content-Type", "application/json").
|
||||||
|
SetHeader("User-Agent", "axios/1.8.4").
|
||||||
SetBody(reqBody).
|
SetBody(reqBody).
|
||||||
SetSuccessResult(&tokenResp).
|
SetSuccessResult(&tokenResp).
|
||||||
Post(s.tokenURL)
|
Post(s.tokenURL)
|
||||||
@@ -205,8 +225,6 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
|||||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||||
client := s.clientFactory(proxyURL)
|
client := s.clientFactory(proxyURL)
|
||||||
|
|
||||||
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
|
|
||||||
// Anthropic OAuth API 期望 JSON 格式的请求体
|
|
||||||
reqBody := map[string]any{
|
reqBody := map[string]any{
|
||||||
"grant_type": "refresh_token",
|
"grant_type": "refresh_token",
|
||||||
"refresh_token": refreshToken,
|
"refresh_token": refreshToken,
|
||||||
@@ -217,7 +235,9 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
|||||||
|
|
||||||
resp, err := client.R().
|
resp, err := client.R().
|
||||||
SetContext(ctx).
|
SetContext(ctx).
|
||||||
|
SetHeader("Accept", "application/json, text/plain, */*").
|
||||||
SetHeader("Content-Type", "application/json").
|
SetHeader("Content-Type", "application/json").
|
||||||
|
SetHeader("User-Agent", "axios/1.8.4").
|
||||||
SetBody(reqBody).
|
SetBody(reqBody).
|
||||||
SetSuccessResult(&tokenResp).
|
SetSuccessResult(&tokenResp).
|
||||||
Post(s.tokenURL)
|
Post(s.tokenURL)
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
|
|||||||
s.client.baseURL = "http://in-process"
|
s.client.baseURL = "http://in-process"
|
||||||
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
|
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
|
||||||
|
|
||||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
|
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "")
|
||||||
|
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
require.Error(s.T(), err)
|
require.Error(s.T(), err)
|
||||||
|
|||||||
@@ -14,37 +14,82 @@ import (
|
|||||||
|
|
||||||
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||||
|
|
||||||
|
// 默认 User-Agent,与用户抓包的请求一致
|
||||||
|
const defaultUsageUserAgent = "claude-code/2.1.7"
|
||||||
|
|
||||||
type claudeUsageService struct {
|
type claudeUsageService struct {
|
||||||
usageURL string
|
usageURL string
|
||||||
allowPrivateHosts bool
|
allowPrivateHosts bool
|
||||||
|
httpUpstream service.HTTPUpstream
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
// NewClaudeUsageFetcher 创建 Claude 用量获取服务
|
||||||
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
|
// httpUpstream: 可选,如果提供则支持 TLS 指纹伪装
|
||||||
|
func NewClaudeUsageFetcher(httpUpstream service.HTTPUpstream) service.ClaudeUsageFetcher {
|
||||||
|
return &claudeUsageService{
|
||||||
|
usageURL: defaultClaudeUsageURL,
|
||||||
|
httpUpstream: httpUpstream,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FetchUsage 简单版本,不支持 TLS 指纹(向后兼容)
|
||||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||||
client, err := httpclient.GetClient(httpclient.Options{
|
return s.FetchUsageWithOptions(ctx, &service.ClaudeUsageFetchOptions{
|
||||||
ProxyURL: proxyURL,
|
AccessToken: accessToken,
|
||||||
Timeout: 30 * time.Second,
|
ProxyURL: proxyURL,
|
||||||
ValidateResolvedIP: true,
|
|
||||||
AllowPrivateHosts: s.allowPrivateHosts,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
}
|
||||||
client = &http.Client{Timeout: 30 * time.Second}
|
|
||||||
|
// FetchUsageWithOptions 完整版本,支持 TLS 指纹和自定义 User-Agent
|
||||||
|
func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *service.ClaudeUsageFetchOptions) (*service.ClaudeUsageResponse, error) {
|
||||||
|
if opts == nil {
|
||||||
|
return nil, fmt.Errorf("options is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 创建请求
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create request failed: %w", err)
|
return nil, fmt.Errorf("create request failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
// 设置请求头(与抓包一致,但不设置 Accept-Encoding,让 Go 自动处理压缩)
|
||||||
|
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+opts.AccessToken)
|
||||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
// 设置 User-Agent(优先使用缓存的 Fingerprint,否则使用默认值)
|
||||||
if err != nil {
|
userAgent := defaultUsageUserAgent
|
||||||
return nil, fmt.Errorf("request failed: %w", err)
|
if opts.Fingerprint != nil && opts.Fingerprint.UserAgent != "" {
|
||||||
|
userAgent = opts.Fingerprint.UserAgent
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", userAgent)
|
||||||
|
|
||||||
|
var resp *http.Response
|
||||||
|
|
||||||
|
// 如果启用 TLS 指纹且有 HTTPUpstream,使用 DoWithTLS
|
||||||
|
if opts.EnableTLSFingerprint && s.httpUpstream != nil {
|
||||||
|
// accountConcurrency 传 0 使用默认连接池配置,usage 请求不需要特殊的并发设置
|
||||||
|
resp, err = s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, 0, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request with TLS fingerprint failed: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 不启用 TLS 指纹,使用普通 HTTP 客户端
|
||||||
|
client, err := httpclient.GetClient(httpclient.Options{
|
||||||
|
ProxyURL: opts.ProxyURL,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
ValidateResolvedIP: true,
|
||||||
|
AllowPrivateHosts: s.allowPrivateHosts,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
client = &http.Client{Timeout: 30 * time.Second}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
|||||||
@@ -9,13 +9,27 @@ import (
|
|||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
|
|
||||||
const verifyCodeKeyPrefix = "verify_code:"
|
const (
|
||||||
|
verifyCodeKeyPrefix = "verify_code:"
|
||||||
|
passwordResetKeyPrefix = "password_reset:"
|
||||||
|
passwordResetSentAtKeyPrefix = "password_reset_sent:"
|
||||||
|
)
|
||||||
|
|
||||||
// verifyCodeKey generates the Redis key for email verification code.
|
// verifyCodeKey generates the Redis key for email verification code.
|
||||||
func verifyCodeKey(email string) string {
|
func verifyCodeKey(email string) string {
|
||||||
return verifyCodeKeyPrefix + email
|
return verifyCodeKeyPrefix + email
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// passwordResetKey generates the Redis key for password reset token.
|
||||||
|
func passwordResetKey(email string) string {
|
||||||
|
return passwordResetKeyPrefix + email
|
||||||
|
}
|
||||||
|
|
||||||
|
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
|
||||||
|
func passwordResetSentAtKey(email string) string {
|
||||||
|
return passwordResetSentAtKeyPrefix + email
|
||||||
|
}
|
||||||
|
|
||||||
type emailCache struct {
|
type emailCache struct {
|
||||||
rdb *redis.Client
|
rdb *redis.Client
|
||||||
}
|
}
|
||||||
@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
|
|||||||
key := verifyCodeKey(email)
|
key := verifyCodeKey(email)
|
||||||
return c.rdb.Del(ctx, key).Err()
|
return c.rdb.Del(ctx, key).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Password reset token methods
|
||||||
|
|
||||||
|
func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) {
|
||||||
|
key := passwordResetKey(email)
|
||||||
|
val, err := c.rdb.Get(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var data service.PasswordResetTokenData
|
||||||
|
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error {
|
||||||
|
key := passwordResetKey(email)
|
||||||
|
val, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error {
|
||||||
|
key := passwordResetKey(email)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Password reset email cooldown methods
|
||||||
|
|
||||||
|
func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
|
||||||
|
key := passwordResetSentAtKey(email)
|
||||||
|
exists, err := c.rdb.Exists(ctx, key).Result()
|
||||||
|
return err == nil && exists > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
|
||||||
|
key := passwordResetSentAtKey(email)
|
||||||
|
return c.rdb.Set(ctx, key, "1", ttl).Err()
|
||||||
|
}
|
||||||
|
|||||||
@@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
|
|||||||
key := buildSessionKey(groupID, sessionHash)
|
key := buildSessionKey(groupID, sessionHash)
|
||||||
return c.rdb.Expire(ctx, key, ttl).Err()
|
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
|
||||||
|
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
|
||||||
|
// 以便下次请求能够重新选择可用账号。
|
||||||
|
//
|
||||||
|
// DeleteSessionAccountID removes the sticky session binding for the given session.
|
||||||
|
// Called when the bound account becomes unavailable (e.g., error status, disabled,
|
||||||
|
// or unschedulable), allowing subsequent requests to select a new available account.
|
||||||
|
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||||
|
key := buildSessionKey(groupID, sessionHash)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|||||||
@@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
|
|||||||
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
|
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
|
||||||
|
sessionID := "openai:s4"
|
||||||
|
accountID := int64(102)
|
||||||
|
groupID := int64(1)
|
||||||
|
sessionTTL := 1 * time.Minute
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||||
|
require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
|
||||||
|
|
||||||
|
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||||
|
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||||
sessionID := "corrupted"
|
sessionID := "corrupted"
|
||||||
groupID := int64(1)
|
groupID := int64(1)
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() {
|
|||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
tx := testEntTx(s.T())
|
tx := testEntTx(s.T())
|
||||||
s.client = tx.Client()
|
s.client = tx.Client()
|
||||||
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
|
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGatewayRoutingSuite(t *testing.T) {
|
func TestGatewayRoutingSuite(t *testing.T) {
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
@@ -38,16 +39,17 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
|||||||
|
|
||||||
resp, err := client.R().
|
resp, err := client.R().
|
||||||
SetContext(ctx).
|
SetContext(ctx).
|
||||||
|
SetHeader("User-Agent", "codex-cli/0.91.0").
|
||||||
SetFormDataFromValues(formData).
|
SetFormDataFromValues(formData).
|
||||||
SetSuccessResult(&tokenResp).
|
SetSuccessResult(&tokenResp).
|
||||||
Post(s.tokenURL)
|
Post(s.tokenURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("request failed: %w", err)
|
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !resp.IsSuccessState() {
|
if !resp.IsSuccessState() {
|
||||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED", "token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return &tokenResp, nil
|
return &tokenResp, nil
|
||||||
@@ -66,16 +68,17 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
|||||||
|
|
||||||
resp, err := client.R().
|
resp, err := client.R().
|
||||||
SetContext(ctx).
|
SetContext(ctx).
|
||||||
|
SetHeader("User-Agent", "codex-cli/0.91.0").
|
||||||
SetFormDataFromValues(formData).
|
SetFormDataFromValues(formData).
|
||||||
SetSuccessResult(&tokenResp).
|
SetSuccessResult(&tokenResp).
|
||||||
Post(s.tokenURL)
|
Post(s.tokenURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("request failed: %w", err)
|
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !resp.IsSuccessState() {
|
if !resp.IsSuccessState() {
|
||||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return &tokenResp, nil
|
return &tokenResp, nil
|
||||||
@@ -84,6 +87,6 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
|||||||
func createOpenAIReqClient(proxyURL string) *req.Client {
|
func createOpenAIReqClient(proxyURL string) *req.Client {
|
||||||
return getSharedReqClient(reqClientOptions{
|
return getSharedReqClient(reqClientOptions{
|
||||||
ProxyURL: proxyURL,
|
ProxyURL: proxyURL,
|
||||||
Timeout: 60 * time.Second,
|
Timeout: 120 * time.Second,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
|
|||||||
require.ErrorContains(s.T(), err, "status 401")
|
require.ErrorContains(s.T(), err, "status 401")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
|
||||||
|
client := NewOpenAIOAuthClient()
|
||||||
|
svc, ok := client.(*openaiOAuthService)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, openai.TokenURL, svc.tokenURL)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIOAuthServiceSuite(t *testing.T) {
|
func TestOpenAIOAuthServiceSuite(t *testing.T) {
|
||||||
suite.Run(t, new(OpenAIOAuthServiceSuite))
|
suite.Run(t, new(OpenAIOAuthServiceSuite))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ type reqClientOptions struct {
|
|||||||
ProxyURL string // 代理 URL(支持 http/https/socks5)
|
ProxyURL string // 代理 URL(支持 http/https/socks5)
|
||||||
Timeout time.Duration // 请求超时时间
|
Timeout time.Duration // 请求超时时间
|
||||||
Impersonate bool // 是否模拟 Chrome 浏览器指纹
|
Impersonate bool // 是否模拟 Chrome 浏览器指纹
|
||||||
|
ForceHTTP2 bool // 是否强制使用 HTTP/2
|
||||||
}
|
}
|
||||||
|
|
||||||
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
|
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
|
||||||
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
client := req.C().SetTimeout(opts.Timeout)
|
client := req.C().SetTimeout(opts.Timeout)
|
||||||
|
if opts.ForceHTTP2 {
|
||||||
|
client = client.EnableForceHTTP2()
|
||||||
|
}
|
||||||
if opts.Impersonate {
|
if opts.Impersonate {
|
||||||
client = client.ImpersonateChrome()
|
client = client.ImpersonateChrome()
|
||||||
}
|
}
|
||||||
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buildReqClientKey(opts reqClientOptions) string {
|
func buildReqClientKey(opts reqClientOptions) string {
|
||||||
return fmt.Sprintf("%s|%s|%t",
|
return fmt.Sprintf("%s|%s|%t|%t",
|
||||||
strings.TrimSpace(opts.ProxyURL),
|
strings.TrimSpace(opts.ProxyURL),
|
||||||
opts.Timeout.String(),
|
opts.Timeout.String(),
|
||||||
opts.Impersonate,
|
opts.Impersonate,
|
||||||
|
opts.ForceHTTP2,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
90
backend/internal/repository/req_client_pool_test.go
Normal file
90
backend/internal/repository/req_client_pool_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func forceHTTPVersion(t *testing.T, client *req.Client) string {
|
||||||
|
t.Helper()
|
||||||
|
transport := client.GetTransport()
|
||||||
|
field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion")
|
||||||
|
require.True(t, field.IsValid(), "forceHttpVersion field not found")
|
||||||
|
require.True(t, field.CanAddr(), "forceHttpVersion field not addressable")
|
||||||
|
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
base := reqClientOptions{
|
||||||
|
ProxyURL: "http://proxy.local:8080",
|
||||||
|
Timeout: time.Second,
|
||||||
|
}
|
||||||
|
clientDefault := getSharedReqClient(base)
|
||||||
|
|
||||||
|
force := base
|
||||||
|
force.ForceHTTP2 = true
|
||||||
|
clientForce := getSharedReqClient(force)
|
||||||
|
|
||||||
|
require.NotSame(t, clientDefault, clientForce)
|
||||||
|
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
opts := reqClientOptions{
|
||||||
|
ProxyURL: "http://proxy.local:8080",
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
first := getSharedReqClient(opts)
|
||||||
|
second := getSharedReqClient(opts)
|
||||||
|
require.Same(t, first, second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
opts := reqClientOptions{
|
||||||
|
ProxyURL: " http://proxy.local:8080 ",
|
||||||
|
Timeout: 3 * time.Second,
|
||||||
|
}
|
||||||
|
key := buildReqClientKey(opts)
|
||||||
|
sharedReqClients.Store(key, "invalid")
|
||||||
|
|
||||||
|
client := getSharedReqClient(opts)
|
||||||
|
|
||||||
|
require.NotNil(t, client)
|
||||||
|
loaded, ok := sharedReqClients.Load(key)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.IsType(t, "invalid", loaded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
opts := reqClientOptions{
|
||||||
|
ProxyURL: " http://proxy.local:8080 ",
|
||||||
|
Timeout: 4 * time.Second,
|
||||||
|
Impersonate: true,
|
||||||
|
}
|
||||||
|
client := getSharedReqClient(opts)
|
||||||
|
|
||||||
|
require.NotNil(t, client)
|
||||||
|
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
client := createOpenAIReqClient("http://proxy.local:8080")
|
||||||
|
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
client := createGeminiReqClient("http://proxy.local:8080")
|
||||||
|
require.Equal(t, "", forceHTTPVersion(t, client))
|
||||||
|
}
|
||||||
@@ -58,7 +58,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
|
|||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
if len(ids) == 0 {
|
if len(ids) == 0 {
|
||||||
return []*service.Account{}, true, nil
|
// 空快照视为缓存未命中,触发数据库回退查询
|
||||||
|
// 这解决了新分组创建后立即绑定账号时的竞态条件问题
|
||||||
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
keys := make([]string, 0, len(ids))
|
keys := make([]string, 0, len(ids))
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
|
|||||||
|
|
||||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
|
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
|
||||||
|
|
||||||
accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
|
accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil)
|
||||||
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
|
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
|
||||||
cache := NewSchedulerCache(rdb)
|
cache := NewSchedulerCache(rdb)
|
||||||
|
|
||||||
|
|||||||
@@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
|
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
|
||||||
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) {
|
||||||
if len(accountIDs) == 0 {
|
if len(accountIDs) == 0 {
|
||||||
return make(map[int64]int), nil
|
return make(map[int64]int), nil
|
||||||
}
|
}
|
||||||
@@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco
|
|||||||
|
|
||||||
// 使用 pipeline 批量执行
|
// 使用 pipeline 批量执行
|
||||||
pipe := c.rdb.Pipeline()
|
pipe := c.rdb.Pipeline()
|
||||||
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
|
|
||||||
|
|
||||||
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
|
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
|
||||||
for _, accountID := range accountIDs {
|
for _, accountID := range accountIDs {
|
||||||
key := sessionLimitKey(accountID)
|
key := sessionLimitKey(accountID)
|
||||||
|
// 使用各账号自己的 idleTimeout,如果没有则用默认值
|
||||||
|
idleTimeout := c.defaultIdleTimeout
|
||||||
|
if idleTimeouts != nil {
|
||||||
|
if t, ok := idleTimeouts[accountID]; ok && t > 0 {
|
||||||
|
idleTimeout = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
idleTimeoutSeconds := int(idleTimeout.Seconds())
|
||||||
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
|
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
149
backend/internal/repository/totp_cache.go
Normal file
149
backend/internal/repository/totp_cache.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
totpSetupKeyPrefix = "totp:setup:"
|
||||||
|
totpLoginKeyPrefix = "totp:login:"
|
||||||
|
totpAttemptsKeyPrefix = "totp:attempts:"
|
||||||
|
totpAttemptsTTL = 15 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// TotpCache implements service.TotpCache using Redis
|
||||||
|
type TotpCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTotpCache creates a new TOTP cache
|
||||||
|
func NewTotpCache(rdb *redis.Client) service.TotpCache {
|
||||||
|
return &TotpCache{rdb: rdb}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSetupSession retrieves a TOTP setup session
|
||||||
|
func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
|
||||||
|
data, err := c.rdb.Get(ctx, key).Bytes()
|
||||||
|
if err != nil {
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("get setup session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var session service.TotpSetupSession
|
||||||
|
if err := json.Unmarshal(data, &session); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal setup session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSetupSession stores a TOTP setup session
|
||||||
|
func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error {
|
||||||
|
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
|
||||||
|
data, err := json.Marshal(session)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal setup session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
|
||||||
|
return fmt.Errorf("set setup session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSetupSession deletes a TOTP setup session
|
||||||
|
func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLoginSession retrieves a TOTP login session
|
||||||
|
func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) {
|
||||||
|
key := totpLoginKeyPrefix + tempToken
|
||||||
|
data, err := c.rdb.Get(ctx, key).Bytes()
|
||||||
|
if err != nil {
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("get login session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var session service.TotpLoginSession
|
||||||
|
if err := json.Unmarshal(data, &session); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal login session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLoginSession stores a TOTP login session
|
||||||
|
func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error {
|
||||||
|
key := totpLoginKeyPrefix + tempToken
|
||||||
|
data, err := json.Marshal(session)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal login session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
|
||||||
|
return fmt.Errorf("set login session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteLoginSession deletes a TOTP login session
|
||||||
|
func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error {
|
||||||
|
key := totpLoginKeyPrefix + tempToken
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementVerifyAttempts increments the verify attempt counter
|
||||||
|
func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
|
||||||
|
|
||||||
|
// Use pipeline for atomic increment and set TTL
|
||||||
|
pipe := c.rdb.Pipeline()
|
||||||
|
incrCmd := pipe.Incr(ctx, key)
|
||||||
|
pipe.Expire(ctx, key, totpAttemptsTTL)
|
||||||
|
|
||||||
|
if _, err := pipe.Exec(ctx); err != nil {
|
||||||
|
return 0, fmt.Errorf("increment verify attempts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := incrCmd.Result()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("get increment result: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(count), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVerifyAttempts gets the current verify attempt count
|
||||||
|
func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
|
||||||
|
count, err := c.rdb.Get(ctx, key).Int()
|
||||||
|
if err != nil {
|
||||||
|
if err == redis.Nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("get verify attempts: %w", err)
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearVerifyAttempts clears the verify attempt counter
|
||||||
|
func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
@@ -466,3 +467,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
|||||||
dst.CreatedAt = src.CreatedAt
|
dst.CreatedAt = src.CreatedAt
|
||||||
dst.UpdatedAt = src.UpdatedAt
|
dst.UpdatedAt = src.UpdatedAt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
|
||||||
|
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
update := client.User.UpdateOneID(userID)
|
||||||
|
if encryptedSecret == nil {
|
||||||
|
update = update.ClearTotpSecretEncrypted()
|
||||||
|
} else {
|
||||||
|
update = update.SetTotpSecretEncrypted(*encryptedSecret)
|
||||||
|
}
|
||||||
|
_, err := update.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableTotp 启用用户的 TOTP 双因素认证
|
||||||
|
func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
_, err := client.User.UpdateOneID(userID).
|
||||||
|
SetTotpEnabled(true).
|
||||||
|
SetTotpEnabledAt(time.Now()).
|
||||||
|
Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableTotp 禁用用户的 TOTP 双因素认证
|
||||||
|
func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
|
||||||
|
client := clientFromContext(ctx, r.client)
|
||||||
|
_, err := client.User.UpdateOneID(userID).
|
||||||
|
SetTotpEnabled(false).
|
||||||
|
ClearTotpEnabledAt().
|
||||||
|
ClearTotpSecretEncrypted().
|
||||||
|
Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
|
|||||||
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
client := clientFromContext(ctx, r.client)
|
client := clientFromContext(ctx, r.client)
|
||||||
q := client.UserSubscription.Query()
|
q := client.UserSubscription.Query()
|
||||||
if userID != nil {
|
if userID != nil {
|
||||||
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
|
|||||||
if groupID != nil {
|
if groupID != nil {
|
||||||
q = q.Where(usersubscription.GroupIDEQ(*groupID))
|
q = q.Where(usersubscription.GroupIDEQ(*groupID))
|
||||||
}
|
}
|
||||||
if status != "" {
|
|
||||||
|
// Status filtering with real-time expiration check
|
||||||
|
now := time.Now()
|
||||||
|
switch status {
|
||||||
|
case service.SubscriptionStatusActive:
|
||||||
|
// Active: status is active AND not yet expired
|
||||||
|
q = q.Where(
|
||||||
|
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||||
|
usersubscription.ExpiresAtGT(now),
|
||||||
|
)
|
||||||
|
case service.SubscriptionStatusExpired:
|
||||||
|
// Expired: status is expired OR (status is active but already expired)
|
||||||
|
q = q.Where(
|
||||||
|
usersubscription.Or(
|
||||||
|
usersubscription.StatusEQ(service.SubscriptionStatusExpired),
|
||||||
|
usersubscription.And(
|
||||||
|
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||||
|
usersubscription.ExpiresAtLTE(now),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
case "":
|
||||||
|
// No filter
|
||||||
|
default:
|
||||||
|
// Other status (e.g., revoked)
|
||||||
q = q.Where(usersubscription.StatusEQ(status))
|
q = q.Where(usersubscription.StatusEQ(status))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply sorting
|
||||||
|
q = q.WithUser().WithGroup().WithAssignedByUser()
|
||||||
|
|
||||||
|
// Determine sort field
|
||||||
|
var field string
|
||||||
|
switch sortBy {
|
||||||
|
case "expires_at":
|
||||||
|
field = usersubscription.FieldExpiresAt
|
||||||
|
case "status":
|
||||||
|
field = usersubscription.FieldStatus
|
||||||
|
default:
|
||||||
|
field = usersubscription.FieldCreatedAt
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine sort order (default: desc)
|
||||||
|
if sortOrder == "asc" && sortBy != "" {
|
||||||
|
q = q.Order(dbent.Asc(field))
|
||||||
|
} else {
|
||||||
|
q = q.Order(dbent.Desc(field))
|
||||||
|
}
|
||||||
|
|
||||||
subs, err := q.
|
subs, err := q.
|
||||||
WithUser().
|
|
||||||
WithGroup().
|
|
||||||
WithAssignedByUser().
|
|
||||||
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
|
|
||||||
Offset(params.Offset()).
|
Offset(params.Offset()).
|
||||||
Limit(params.Limit()).
|
Limit(params.Limit()).
|
||||||
All(ctx)
|
All(ctx)
|
||||||
|
|||||||
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
|||||||
group := s.mustCreateGroup("g-list")
|
group := s.mustCreateGroup("g-list")
|
||||||
s.mustCreateSubscription(user.ID, group.ID, nil)
|
s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||||
|
|
||||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
|
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
|
||||||
s.Require().NoError(err, "List")
|
s.Require().NoError(err, "List")
|
||||||
s.Require().Len(subs, 1)
|
s.Require().Len(subs, 1)
|
||||||
s.Require().Equal(int64(1), page.Total)
|
s.Require().Equal(int64(1), page.Total)
|
||||||
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
|||||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||||
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
||||||
|
|
||||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
|
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(subs, 1)
|
s.Require().Len(subs, 1)
|
||||||
s.Require().Equal(user1.ID, subs[0].UserID)
|
s.Require().Equal(user1.ID, subs[0].UserID)
|
||||||
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
|||||||
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
||||||
s.mustCreateSubscription(user.ID, g2.ID, nil)
|
s.mustCreateSubscription(user.ID, g2.ID, nil)
|
||||||
|
|
||||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
|
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(subs, 1)
|
s.Require().Len(subs, 1)
|
||||||
s.Require().Equal(g1.ID, subs[0].GroupID)
|
s.Require().Equal(g1.ID, subs[0].GroupID)
|
||||||
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
|
|||||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||||
})
|
})
|
||||||
|
|
||||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
|
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Len(subs, 1)
|
s.Require().Len(subs, 1)
|
||||||
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
|
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
|
||||||
|
|||||||
@@ -82,6 +82,10 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewSchedulerCache,
|
NewSchedulerCache,
|
||||||
NewSchedulerOutboxRepository,
|
NewSchedulerOutboxRepository,
|
||||||
NewProxyLatencyCache,
|
NewProxyLatencyCache,
|
||||||
|
NewTotpCache,
|
||||||
|
|
||||||
|
// Encryptors
|
||||||
|
NewAESEncryptor,
|
||||||
|
|
||||||
// HTTP service ports (DI Strategy A: return interface directly)
|
// HTTP service ports (DI Strategy A: return interface directly)
|
||||||
NewTurnstileVerifier,
|
NewTurnstileVerifier,
|
||||||
|
|||||||
@@ -51,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"id": 1,
|
"id": 1,
|
||||||
"email": "alice@example.com",
|
"email": "alice@example.com",
|
||||||
"username": "alice",
|
"username": "alice",
|
||||||
"notes": "hello",
|
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"balance": 12.5,
|
"balance": 12.5,
|
||||||
"concurrency": 5,
|
"concurrency": 5,
|
||||||
@@ -131,6 +130,153 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "GET /api/v1/groups/available",
|
||||||
|
setup: func(t *testing.T, deps *contractDeps) {
|
||||||
|
t.Helper()
|
||||||
|
// 普通用户可见的分组列表不应包含内部字段(如 model_routing/account_count)。
|
||||||
|
deps.groupRepo.SetActive([]service.Group{
|
||||||
|
{
|
||||||
|
ID: 10,
|
||||||
|
Name: "Group One",
|
||||||
|
Description: "desc",
|
||||||
|
Platform: service.PlatformAnthropic,
|
||||||
|
RateMultiplier: 1.5,
|
||||||
|
IsExclusive: false,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
SubscriptionType: service.SubscriptionTypeStandard,
|
||||||
|
ModelRoutingEnabled: true,
|
||||||
|
ModelRouting: map[string][]int64{
|
||||||
|
"claude-3-*": []int64{101, 102},
|
||||||
|
},
|
||||||
|
AccountCount: 2,
|
||||||
|
CreatedAt: deps.now,
|
||||||
|
UpdatedAt: deps.now,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
deps.userSubRepo.SetActiveByUserID(1, nil)
|
||||||
|
},
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/v1/groups/available",
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
wantJSON: `{
|
||||||
|
"code": 0,
|
||||||
|
"message": "success",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": 10,
|
||||||
|
"name": "Group One",
|
||||||
|
"description": "desc",
|
||||||
|
"platform": "anthropic",
|
||||||
|
"rate_multiplier": 1.5,
|
||||||
|
"is_exclusive": false,
|
||||||
|
"status": "active",
|
||||||
|
"subscription_type": "standard",
|
||||||
|
"daily_limit_usd": null,
|
||||||
|
"weekly_limit_usd": null,
|
||||||
|
"monthly_limit_usd": null,
|
||||||
|
"image_price_1k": null,
|
||||||
|
"image_price_2k": null,
|
||||||
|
"image_price_4k": null,
|
||||||
|
"claude_code_only": false,
|
||||||
|
"fallback_group_id": null,
|
||||||
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
|
"updated_at": "2025-01-02T03:04:05Z"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GET /api/v1/subscriptions",
|
||||||
|
setup: func(t *testing.T, deps *contractDeps) {
|
||||||
|
t.Helper()
|
||||||
|
// 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。
|
||||||
|
deps.userSubRepo.SetByUserID(1, []service.UserSubscription{
|
||||||
|
{
|
||||||
|
ID: 501,
|
||||||
|
UserID: 1,
|
||||||
|
GroupID: 10,
|
||||||
|
StartsAt: deps.now,
|
||||||
|
ExpiresAt: time.Date(2099, 1, 2, 3, 4, 5, 0, time.UTC), // 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
|
||||||
|
Status: service.SubscriptionStatusActive,
|
||||||
|
DailyUsageUSD: 1.23,
|
||||||
|
WeeklyUsageUSD: 2.34,
|
||||||
|
MonthlyUsageUSD: 3.45,
|
||||||
|
AssignedBy: ptr(int64(999)),
|
||||||
|
AssignedAt: deps.now,
|
||||||
|
Notes: "admin-note",
|
||||||
|
CreatedAt: deps.now,
|
||||||
|
UpdatedAt: deps.now,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
},
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/v1/subscriptions",
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
wantJSON: `{
|
||||||
|
"code": 0,
|
||||||
|
"message": "success",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": 501,
|
||||||
|
"user_id": 1,
|
||||||
|
"group_id": 10,
|
||||||
|
"starts_at": "2025-01-02T03:04:05Z",
|
||||||
|
"expires_at": "2099-01-02T03:04:05Z",
|
||||||
|
"status": "active",
|
||||||
|
"daily_window_start": null,
|
||||||
|
"weekly_window_start": null,
|
||||||
|
"monthly_window_start": null,
|
||||||
|
"daily_usage_usd": 1.23,
|
||||||
|
"weekly_usage_usd": 2.34,
|
||||||
|
"monthly_usage_usd": 3.45,
|
||||||
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
|
"updated_at": "2025-01-02T03:04:05Z"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GET /api/v1/redeem/history",
|
||||||
|
setup: func(t *testing.T, deps *contractDeps) {
|
||||||
|
t.Helper()
|
||||||
|
// 普通用户兑换历史不应包含 notes 等内部字段。
|
||||||
|
deps.redeemRepo.SetByUser(1, []service.RedeemCode{
|
||||||
|
{
|
||||||
|
ID: 900,
|
||||||
|
Code: "CODE-123",
|
||||||
|
Type: service.RedeemTypeBalance,
|
||||||
|
Value: 1.25,
|
||||||
|
Status: service.StatusUsed,
|
||||||
|
UsedBy: ptr(int64(1)),
|
||||||
|
UsedAt: ptr(deps.now),
|
||||||
|
Notes: "internal-note",
|
||||||
|
CreatedAt: deps.now,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
},
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/api/v1/redeem/history",
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
wantJSON: `{
|
||||||
|
"code": 0,
|
||||||
|
"message": "success",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": 900,
|
||||||
|
"code": "CODE-123",
|
||||||
|
"type": "balance",
|
||||||
|
"value": 1.25,
|
||||||
|
"status": "used",
|
||||||
|
"used_by": 1,
|
||||||
|
"used_at": "2025-01-02T03:04:05Z",
|
||||||
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
|
"group_id": null,
|
||||||
|
"validity_days": 0
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "GET /api/v1/usage/stats",
|
name: "GET /api/v1/usage/stats",
|
||||||
setup: func(t *testing.T, deps *contractDeps) {
|
setup: func(t *testing.T, deps *contractDeps) {
|
||||||
@@ -190,24 +336,25 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
deps.usageRepo.SetUserLogs(1, []service.UsageLog{
|
deps.usageRepo.SetUserLogs(1, []service.UsageLog{
|
||||||
{
|
{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
UserID: 1,
|
UserID: 1,
|
||||||
APIKeyID: 100,
|
APIKeyID: 100,
|
||||||
AccountID: 200,
|
AccountID: 200,
|
||||||
RequestID: "req_123",
|
AccountRateMultiplier: ptr(0.5),
|
||||||
Model: "claude-3",
|
RequestID: "req_123",
|
||||||
InputTokens: 10,
|
Model: "claude-3",
|
||||||
OutputTokens: 20,
|
InputTokens: 10,
|
||||||
CacheCreationTokens: 1,
|
OutputTokens: 20,
|
||||||
CacheReadTokens: 2,
|
CacheCreationTokens: 1,
|
||||||
TotalCost: 0.5,
|
CacheReadTokens: 2,
|
||||||
ActualCost: 0.5,
|
TotalCost: 0.5,
|
||||||
RateMultiplier: 1,
|
ActualCost: 0.5,
|
||||||
BillingType: service.BillingTypeBalance,
|
RateMultiplier: 1,
|
||||||
Stream: true,
|
BillingType: service.BillingTypeBalance,
|
||||||
DurationMs: ptr(100),
|
Stream: true,
|
||||||
FirstTokenMs: ptr(50),
|
DurationMs: ptr(100),
|
||||||
CreatedAt: deps.now,
|
FirstTokenMs: ptr(50),
|
||||||
|
CreatedAt: deps.now,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
@@ -238,10 +385,9 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"output_cost": 0,
|
"output_cost": 0,
|
||||||
"cache_creation_cost": 0,
|
"cache_creation_cost": 0,
|
||||||
"cache_read_cost": 0,
|
"cache_read_cost": 0,
|
||||||
"total_cost": 0.5,
|
"total_cost": 0.5,
|
||||||
"actual_cost": 0.5,
|
"actual_cost": 0.5,
|
||||||
"rate_multiplier": 1,
|
"rate_multiplier": 1,
|
||||||
"account_rate_multiplier": null,
|
|
||||||
"billing_type": 0,
|
"billing_type": 0,
|
||||||
"stream": true,
|
"stream": true,
|
||||||
"duration_ms": 100,
|
"duration_ms": 100,
|
||||||
@@ -266,6 +412,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
deps.settingRepo.SetAll(map[string]string{
|
deps.settingRepo.SetAll(map[string]string{
|
||||||
service.SettingKeyRegistrationEnabled: "true",
|
service.SettingKeyRegistrationEnabled: "true",
|
||||||
service.SettingKeyEmailVerifyEnabled: "false",
|
service.SettingKeyEmailVerifyEnabled: "false",
|
||||||
|
service.SettingKeyPromoCodeEnabled: "true",
|
||||||
|
|
||||||
service.SettingKeySMTPHost: "smtp.example.com",
|
service.SettingKeySMTPHost: "smtp.example.com",
|
||||||
service.SettingKeySMTPPort: "587",
|
service.SettingKeySMTPPort: "587",
|
||||||
@@ -304,6 +451,10 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"data": {
|
"data": {
|
||||||
"registration_enabled": true,
|
"registration_enabled": true,
|
||||||
"email_verify_enabled": false,
|
"email_verify_enabled": false,
|
||||||
|
"promo_code_enabled": true,
|
||||||
|
"password_reset_enabled": false,
|
||||||
|
"totp_enabled": false,
|
||||||
|
"totp_encryption_key_configured": false,
|
||||||
"smtp_host": "smtp.example.com",
|
"smtp_host": "smtp.example.com",
|
||||||
"smtp_port": 587,
|
"smtp_port": 587,
|
||||||
"smtp_username": "user",
|
"smtp_username": "user",
|
||||||
@@ -337,7 +488,10 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"fallback_model_openai": "gpt-4o",
|
"fallback_model_openai": "gpt-4o",
|
||||||
"enable_identity_patch": true,
|
"enable_identity_patch": true,
|
||||||
"identity_patch_prompt": "",
|
"identity_patch_prompt": "",
|
||||||
"home_content": ""
|
"home_content": "",
|
||||||
|
"hide_ccs_import_button": false,
|
||||||
|
"purchase_subscription_enabled": false,
|
||||||
|
"purchase_subscription_url": ""
|
||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
@@ -385,8 +539,11 @@ type contractDeps struct {
|
|||||||
now time.Time
|
now time.Time
|
||||||
router http.Handler
|
router http.Handler
|
||||||
apiKeyRepo *stubApiKeyRepo
|
apiKeyRepo *stubApiKeyRepo
|
||||||
|
groupRepo *stubGroupRepo
|
||||||
|
userSubRepo *stubUserSubscriptionRepo
|
||||||
usageRepo *stubUsageLogRepo
|
usageRepo *stubUsageLogRepo
|
||||||
settingRepo *stubSettingRepo
|
settingRepo *stubSettingRepo
|
||||||
|
redeemRepo *stubRedeemCodeRepo
|
||||||
}
|
}
|
||||||
|
|
||||||
func newContractDeps(t *testing.T) *contractDeps {
|
func newContractDeps(t *testing.T) *contractDeps {
|
||||||
@@ -414,11 +571,11 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
|
|
||||||
apiKeyRepo := newStubApiKeyRepo(now)
|
apiKeyRepo := newStubApiKeyRepo(now)
|
||||||
apiKeyCache := stubApiKeyCache{}
|
apiKeyCache := stubApiKeyCache{}
|
||||||
groupRepo := stubGroupRepo{}
|
groupRepo := &stubGroupRepo{}
|
||||||
userSubRepo := stubUserSubscriptionRepo{}
|
userSubRepo := &stubUserSubscriptionRepo{}
|
||||||
accountRepo := stubAccountRepo{}
|
accountRepo := stubAccountRepo{}
|
||||||
proxyRepo := stubProxyRepo{}
|
proxyRepo := stubProxyRepo{}
|
||||||
redeemRepo := stubRedeemCodeRepo{}
|
redeemRepo := &stubRedeemCodeRepo{}
|
||||||
|
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
Default: config.DefaultConfig{
|
Default: config.DefaultConfig{
|
||||||
@@ -433,15 +590,21 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
usageRepo := newStubUsageLogRepo()
|
usageRepo := newStubUsageLogRepo()
|
||||||
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
|
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
|
||||||
|
|
||||||
|
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil)
|
||||||
|
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||||
|
|
||||||
|
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
|
||||||
|
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||||
|
|
||||||
settingRepo := newStubSettingRepo()
|
settingRepo := newStubSettingRepo()
|
||||||
settingService := service.NewSettingService(settingRepo, cfg)
|
settingService := service.NewSettingService(settingRepo, cfg)
|
||||||
|
|
||||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
|
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
|
||||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
||||||
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
jwtAuth := func(c *gin.Context) {
|
jwtAuth := func(c *gin.Context) {
|
||||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||||
@@ -472,12 +635,21 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
v1Keys.Use(jwtAuth)
|
v1Keys.Use(jwtAuth)
|
||||||
v1Keys.GET("/keys", apiKeyHandler.List)
|
v1Keys.GET("/keys", apiKeyHandler.List)
|
||||||
v1Keys.POST("/keys", apiKeyHandler.Create)
|
v1Keys.POST("/keys", apiKeyHandler.Create)
|
||||||
|
v1Keys.GET("/groups/available", apiKeyHandler.GetAvailableGroups)
|
||||||
|
|
||||||
v1Usage := v1.Group("")
|
v1Usage := v1.Group("")
|
||||||
v1Usage.Use(jwtAuth)
|
v1Usage.Use(jwtAuth)
|
||||||
v1Usage.GET("/usage", usageHandler.List)
|
v1Usage.GET("/usage", usageHandler.List)
|
||||||
v1Usage.GET("/usage/stats", usageHandler.Stats)
|
v1Usage.GET("/usage/stats", usageHandler.Stats)
|
||||||
|
|
||||||
|
v1Subs := v1.Group("")
|
||||||
|
v1Subs.Use(jwtAuth)
|
||||||
|
v1Subs.GET("/subscriptions", subscriptionHandler.List)
|
||||||
|
|
||||||
|
v1Redeem := v1.Group("")
|
||||||
|
v1Redeem.Use(jwtAuth)
|
||||||
|
v1Redeem.GET("/redeem/history", redeemHandler.GetHistory)
|
||||||
|
|
||||||
v1Admin := v1.Group("/admin")
|
v1Admin := v1.Group("/admin")
|
||||||
v1Admin.Use(adminAuth)
|
v1Admin.Use(adminAuth)
|
||||||
v1Admin.GET("/settings", adminSettingHandler.GetSettings)
|
v1Admin.GET("/settings", adminSettingHandler.GetSettings)
|
||||||
@@ -487,8 +659,11 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
now: now,
|
now: now,
|
||||||
router: r,
|
router: r,
|
||||||
apiKeyRepo: apiKeyRepo,
|
apiKeyRepo: apiKeyRepo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
userSubRepo: userSubRepo,
|
||||||
usageRepo: usageRepo,
|
usageRepo: usageRepo,
|
||||||
settingRepo: settingRepo,
|
settingRepo: settingRepo,
|
||||||
|
redeemRepo: redeemRepo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -584,6 +759,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
|
|||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubApiKeyCache struct{}
|
type stubApiKeyCache struct{}
|
||||||
|
|
||||||
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||||
@@ -626,7 +813,13 @@ func (stubApiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handl
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type stubGroupRepo struct{}
|
type stubGroupRepo struct {
|
||||||
|
active []service.Group
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubGroupRepo) SetActive(groups []service.Group) {
|
||||||
|
r.active = append([]service.Group(nil), groups...)
|
||||||
|
}
|
||||||
|
|
||||||
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
|
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
@@ -660,12 +853,19 @@ func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.Pagi
|
|||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) {
|
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) {
|
||||||
return nil, errors.New("not implemented")
|
return append([]service.Group(nil), r.active...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||||
return nil, errors.New("not implemented")
|
out := make([]service.Group, 0, len(r.active))
|
||||||
|
for i := range r.active {
|
||||||
|
g := r.active[i]
|
||||||
|
if g.Platform == platform {
|
||||||
|
out = append(out, g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||||
@@ -883,7 +1083,16 @@ func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
type stubRedeemCodeRepo struct{}
|
type stubRedeemCodeRepo struct {
|
||||||
|
byUser map[int64][]service.RedeemCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubRedeemCodeRepo) SetByUser(userID int64, codes []service.RedeemCode) {
|
||||||
|
if r.byUser == nil {
|
||||||
|
r.byUser = make(map[int64][]service.RedeemCode)
|
||||||
|
}
|
||||||
|
r.byUser[userID] = append([]service.RedeemCode(nil), codes...)
|
||||||
|
}
|
||||||
|
|
||||||
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
|
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
@@ -921,11 +1130,35 @@ func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination
|
|||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
|
func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
|
||||||
return nil, errors.New("not implemented")
|
if r.byUser == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
codes := r.byUser[userID]
|
||||||
|
if limit > 0 && len(codes) > limit {
|
||||||
|
codes = codes[:limit]
|
||||||
|
}
|
||||||
|
return append([]service.RedeemCode(nil), codes...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type stubUserSubscriptionRepo struct{}
|
type stubUserSubscriptionRepo struct {
|
||||||
|
byUser map[int64][]service.UserSubscription
|
||||||
|
activeByUser map[int64][]service.UserSubscription
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubUserSubscriptionRepo) SetByUserID(userID int64, subs []service.UserSubscription) {
|
||||||
|
if r.byUser == nil {
|
||||||
|
r.byUser = make(map[int64][]service.UserSubscription)
|
||||||
|
}
|
||||||
|
r.byUser[userID] = append([]service.UserSubscription(nil), subs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubUserSubscriptionRepo) SetActiveByUserID(userID int64, subs []service.UserSubscription) {
|
||||||
|
if r.activeByUser == nil {
|
||||||
|
r.activeByUser = make(map[int64][]service.UserSubscription)
|
||||||
|
}
|
||||||
|
r.activeByUser[userID] = append([]service.UserSubscription(nil), subs...)
|
||||||
|
}
|
||||||
|
|
||||||
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
@@ -945,16 +1178,22 @@ func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSub
|
|||||||
func (stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
|
func (stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||||
return nil, errors.New("not implemented")
|
if r.byUser == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return append([]service.UserSubscription(nil), r.byUser[userID]...), nil
|
||||||
}
|
}
|
||||||
func (stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||||
return nil, errors.New("not implemented")
|
if r.activeByUser == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return append([]service.UserSubscription(nil), r.activeByUser[userID]...), nil
|
||||||
}
|
}
|
||||||
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||||
|
|||||||
@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
|
|||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||||
return nil, nil, errors.New("not implemented")
|
return nil, nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,11 +26,20 @@ func RegisterAuthRoutes(
|
|||||||
{
|
{
|
||||||
auth.POST("/register", h.Auth.Register)
|
auth.POST("/register", h.Auth.Register)
|
||||||
auth.POST("/login", h.Auth.Login)
|
auth.POST("/login", h.Auth.Login)
|
||||||
|
auth.POST("/login/2fa", h.Auth.Login2FA)
|
||||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||||
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||||
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
|
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
|
||||||
FailureMode: middleware.RateLimitFailClose,
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
}), h.Auth.ValidatePromoCode)
|
}), h.Auth.ValidatePromoCode)
|
||||||
|
// 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
|
||||||
|
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
|
||||||
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
}), h.Auth.ForgotPassword)
|
||||||
|
// 重置密码接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||||
|
auth.POST("/reset-password", rateLimiter.LimitWithOptions("reset-password", 10, time.Minute, middleware.RateLimitOptions{
|
||||||
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
}), h.Auth.ResetPassword)
|
||||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,17 @@ func RegisterUserRoutes(
|
|||||||
user.GET("/profile", h.User.GetProfile)
|
user.GET("/profile", h.User.GetProfile)
|
||||||
user.PUT("/password", h.User.ChangePassword)
|
user.PUT("/password", h.User.ChangePassword)
|
||||||
user.PUT("", h.User.UpdateProfile)
|
user.PUT("", h.User.UpdateProfile)
|
||||||
|
|
||||||
|
// TOTP 双因素认证
|
||||||
|
totp := user.Group("/totp")
|
||||||
|
{
|
||||||
|
totp.GET("/status", h.Totp.GetStatus)
|
||||||
|
totp.GET("/verification-method", h.Totp.GetVerificationMethod)
|
||||||
|
totp.POST("/send-code", h.Totp.SendVerifyCode)
|
||||||
|
totp.POST("/setup", h.Totp.InitiateSetup)
|
||||||
|
totp.POST("/enable", h.Totp.Enable)
|
||||||
|
totp.POST("/disable", h.Totp.Disable)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// API Key管理
|
// API Key管理
|
||||||
|
|||||||
@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCredentialAsInt64 解析凭证中的 int64 字段
|
||||||
|
// 用于读取 _token_version 等内部字段
|
||||||
|
func (a *Account) GetCredentialAsInt64(key string) int64 {
|
||||||
|
if a == nil || a.Credentials == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
val, ok := a.Credentials[key]
|
||||||
|
if !ok || val == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
switch v := val.(type) {
|
||||||
|
case int64:
|
||||||
|
return v
|
||||||
|
case float64:
|
||||||
|
return int64(v)
|
||||||
|
case int:
|
||||||
|
return int64(v)
|
||||||
|
case json.Number:
|
||||||
|
if i, err := v.Int64(); err == nil {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
if i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Account) IsTempUnschedulableEnabled() bool {
|
func (a *Account) IsTempUnschedulableEnabled() bool {
|
||||||
if a.Credentials == nil {
|
if a.Credentials == nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct {
|
|||||||
} `json:"seven_day_sonnet"`
|
} `json:"seven_day_sonnet"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClaudeUsageFetchOptions 包含获取 Claude 用量数据所需的所有选项
|
||||||
|
type ClaudeUsageFetchOptions struct {
|
||||||
|
AccessToken string // OAuth access token
|
||||||
|
ProxyURL string // 代理 URL(可选)
|
||||||
|
AccountID int64 // 账号 ID(用于 TLS 指纹选择)
|
||||||
|
EnableTLSFingerprint bool // 是否启用 TLS 指纹伪装
|
||||||
|
Fingerprint *Fingerprint // 缓存的指纹信息(User-Agent 等)
|
||||||
|
}
|
||||||
|
|
||||||
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
|
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
|
||||||
type ClaudeUsageFetcher interface {
|
type ClaudeUsageFetcher interface {
|
||||||
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
|
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
|
||||||
|
// FetchUsageWithOptions 使用完整选项获取用量数据,支持 TLS 指纹和自定义 User-Agent
|
||||||
|
FetchUsageWithOptions(ctx context.Context, opts *ClaudeUsageFetchOptions) (*ClaudeUsageResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccountUsageService 账号使用量查询服务
|
// AccountUsageService 账号使用量查询服务
|
||||||
@@ -170,6 +181,7 @@ type AccountUsageService struct {
|
|||||||
geminiQuotaService *GeminiQuotaService
|
geminiQuotaService *GeminiQuotaService
|
||||||
antigravityQuotaFetcher *AntigravityQuotaFetcher
|
antigravityQuotaFetcher *AntigravityQuotaFetcher
|
||||||
cache *UsageCache
|
cache *UsageCache
|
||||||
|
identityCache IdentityCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountUsageService 创建AccountUsageService实例
|
// NewAccountUsageService 创建AccountUsageService实例
|
||||||
@@ -180,6 +192,7 @@ func NewAccountUsageService(
|
|||||||
geminiQuotaService *GeminiQuotaService,
|
geminiQuotaService *GeminiQuotaService,
|
||||||
antigravityQuotaFetcher *AntigravityQuotaFetcher,
|
antigravityQuotaFetcher *AntigravityQuotaFetcher,
|
||||||
cache *UsageCache,
|
cache *UsageCache,
|
||||||
|
identityCache IdentityCache,
|
||||||
) *AccountUsageService {
|
) *AccountUsageService {
|
||||||
return &AccountUsageService{
|
return &AccountUsageService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
@@ -188,6 +201,7 @@ func NewAccountUsageService(
|
|||||||
geminiQuotaService: geminiQuotaService,
|
geminiQuotaService: geminiQuotaService,
|
||||||
antigravityQuotaFetcher: antigravityQuotaFetcher,
|
antigravityQuotaFetcher: antigravityQuotaFetcher,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
|
identityCache: identityCache,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -424,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
|
|||||||
}
|
}
|
||||||
|
|
||||||
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo)
|
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo)
|
||||||
|
// 如果账号开启了 TLS 指纹,则使用 TLS 指纹伪装
|
||||||
|
// 如果有缓存的 Fingerprint,则使用缓存的 User-Agent 等信息
|
||||||
func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
|
func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
|
||||||
accessToken := account.GetCredential("access_token")
|
accessToken := account.GetCredential("access_token")
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
@@ -435,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A
|
|||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
|
// 构建完整的选项
|
||||||
|
opts := &ClaudeUsageFetchOptions{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
AccountID: account.ID,
|
||||||
|
EnableTLSFingerprint: account.IsTLSFingerprintEnabled(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试获取缓存的 Fingerprint(包含 User-Agent 等信息)
|
||||||
|
if s.identityCache != nil {
|
||||||
|
if fp, err := s.identityCache.GetFingerprint(ctx, account.ID); err == nil && fp != nil {
|
||||||
|
opts.Fingerprint = fp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.usageFetcher.FetchUsageWithOptions(ctx, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseTime 尝试多种格式解析时间
|
// parseTime 尝试多种格式解析时间
|
||||||
|
|||||||
@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
|
|||||||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||||
|
panic("unexpected UpdateTotpSecret call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
|
||||||
|
panic("unexpected EnableTotp call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
|
||||||
|
panic("unexpected DisableTotp call")
|
||||||
|
}
|
||||||
|
|
||||||
type groupRepoStub struct {
|
type groupRepoStub struct {
|
||||||
affectedUserIDs []int64
|
affectedUserIDs []int64
|
||||||
deleteErr error
|
deleteErr error
|
||||||
|
|||||||
@@ -1305,6 +1305,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 清理 Schema
|
||||||
|
if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil {
|
||||||
|
injectedBody = cleanedBody
|
||||||
|
log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
|
||||||
|
} else {
|
||||||
|
log.Printf("[Antigravity] Failed to clean schema: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// 包装请求
|
// 包装请求
|
||||||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
|
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1705,6 +1713,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
if u := extractGeminiUsage(parsed); u != nil {
|
if u := extractGeminiUsage(parsed); u != nil {
|
||||||
usage = u
|
usage = u
|
||||||
}
|
}
|
||||||
|
// Check for MALFORMED_FUNCTION_CALL
|
||||||
|
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
||||||
|
if cand, ok := candidates[0].(map[string]any); ok {
|
||||||
|
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
|
||||||
|
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
|
||||||
|
if content, ok := cand["content"]; ok {
|
||||||
|
if b, err := json.Marshal(content); err == nil {
|
||||||
|
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if firstTokenMs == nil {
|
if firstTokenMs == nil {
|
||||||
@@ -1854,6 +1875,20 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
|
|||||||
usage = u
|
usage = u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for MALFORMED_FUNCTION_CALL
|
||||||
|
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
||||||
|
if cand, ok := candidates[0].(map[string]any); ok {
|
||||||
|
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
|
||||||
|
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
|
||||||
|
if content, ok := cand["content"]; ok {
|
||||||
|
if b, err := json.Marshal(content); err == nil {
|
||||||
|
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 保留最后一个有 parts 的响应
|
// 保留最后一个有 parts 的响应
|
||||||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||||||
lastWithParts = parsed
|
lastWithParts = parsed
|
||||||
@@ -1950,6 +1985,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi
|
|||||||
return result, existingParts, setParts
|
return result, existingParts, setParts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
|
||||||
|
// 这个函数会合并所有类型的 parts:text、thinking、functionCall、inlineData 等
|
||||||
|
// 保持原始顺序,只合并连续的普通 text parts
|
||||||
|
func mergeCollectedPartsToResponse(response map[string]any, collectedParts []map[string]any) map[string]any {
|
||||||
|
if len(collectedParts) == 0 {
|
||||||
|
return response
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _, setParts := getOrCreateGeminiParts(response)
|
||||||
|
|
||||||
|
// 合并策略:
|
||||||
|
// 1. 保持原始顺序
|
||||||
|
// 2. 连续的普通 text parts 合并为一个
|
||||||
|
// 3. thinking、functionCall、inlineData 等保持原样
|
||||||
|
var mergedParts []any
|
||||||
|
var textBuffer strings.Builder
|
||||||
|
|
||||||
|
flushTextBuffer := func() {
|
||||||
|
if textBuffer.Len() > 0 {
|
||||||
|
mergedParts = append(mergedParts, map[string]any{
|
||||||
|
"text": textBuffer.String(),
|
||||||
|
})
|
||||||
|
textBuffer.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, part := range collectedParts {
|
||||||
|
// 检查是否是普通 text part
|
||||||
|
if text, ok := part["text"].(string); ok {
|
||||||
|
// 检查是否有 thought 标记
|
||||||
|
if thought, _ := part["thought"].(bool); thought {
|
||||||
|
// thinking part,先刷新 text buffer,然后保留原样
|
||||||
|
flushTextBuffer()
|
||||||
|
mergedParts = append(mergedParts, part)
|
||||||
|
} else {
|
||||||
|
// 普通 text,累积到 buffer
|
||||||
|
_, _ = textBuffer.WriteString(text)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 非 text part(functionCall、inlineData 等),先刷新 text buffer,然后保留原样
|
||||||
|
flushTextBuffer()
|
||||||
|
mergedParts = append(mergedParts, part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 刷新剩余的 text
|
||||||
|
flushTextBuffer()
|
||||||
|
|
||||||
|
setParts(mergedParts)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
|
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
|
||||||
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
|
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
|
||||||
if len(imageParts) == 0 {
|
if len(imageParts) == 0 {
|
||||||
@@ -2133,6 +2220,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
|||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
var last map[string]any
|
var last map[string]any
|
||||||
var lastWithParts map[string]any
|
var lastWithParts map[string]any
|
||||||
|
var collectedParts []map[string]any // 收集所有 parts(包括 text、thinking、functionCall、inlineData 等)
|
||||||
|
|
||||||
type scanEvent struct {
|
type scanEvent struct {
|
||||||
line string
|
line string
|
||||||
@@ -2227,9 +2315,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
|||||||
|
|
||||||
last = parsed
|
last = parsed
|
||||||
|
|
||||||
// 保留最后一个有 parts 的响应
|
// 保留最后一个有 parts 的响应,并收集所有 parts
|
||||||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||||||
lastWithParts = parsed
|
lastWithParts = parsed
|
||||||
|
|
||||||
|
// 收集所有 parts(text、thinking、functionCall、inlineData 等)
|
||||||
|
collectedParts = append(collectedParts, parts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-intervalCh:
|
case <-intervalCh:
|
||||||
@@ -2252,6 +2343,11 @@ returnResponse:
|
|||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 将收集的所有 parts 合并到最终响应中
|
||||||
|
if len(collectedParts) > 0 {
|
||||||
|
finalResponse = mergeCollectedPartsToResponse(finalResponse, collectedParts)
|
||||||
|
}
|
||||||
|
|
||||||
// 序列化为 JSON(Gemini 格式)
|
// 序列化为 JSON(Gemini 格式)
|
||||||
geminiBody, err := json.Marshal(finalResponse)
|
geminiBody, err := json.Marshal(finalResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -2459,3 +2555,55 @@ func isImageGenerationModel(model string) bool {
|
|||||||
modelLower == "gemini-2.5-flash-image-preview" ||
|
modelLower == "gemini-2.5-flash-image-preview" ||
|
||||||
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
|
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cleanGeminiRequest 清理 Gemini 请求体中的 Schema
|
||||||
|
func cleanGeminiRequest(body []byte) ([]byte, error) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(body, &payload); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
modified := false
|
||||||
|
|
||||||
|
// 1. 清理 Tools
|
||||||
|
if tools, ok := payload["tools"].([]any); ok && len(tools) > 0 {
|
||||||
|
for _, t := range tools {
|
||||||
|
toolMap, ok := t.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// function_declarations (snake_case) or functionDeclarations (camelCase)
|
||||||
|
var funcs []any
|
||||||
|
if f, ok := toolMap["functionDeclarations"].([]any); ok {
|
||||||
|
funcs = f
|
||||||
|
} else if f, ok := toolMap["function_declarations"].([]any); ok {
|
||||||
|
funcs = f
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(funcs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range funcs {
|
||||||
|
funcMap, ok := f.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if params, ok := funcMap["parameters"].(map[string]any); ok {
|
||||||
|
antigravity.DeepCleanUndefined(params)
|
||||||
|
cleaned := antigravity.CleanJSONSchema(params)
|
||||||
|
funcMap["parameters"] = cleaned
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !modified {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(payload)
|
||||||
|
}
|
||||||
|
|||||||
@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
|||||||
result.Email = userInfo.Email
|
result.Email = userInfo.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取 project_id(部分账户类型可能没有)
|
// 获取 project_id(部分账户类型可能没有),失败时重试
|
||||||
loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
|
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
|
||||||
if err != nil {
|
if loadErr != nil {
|
||||||
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)
|
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
|
||||||
} else if loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
result.ProjectIDMissing = true
|
||||||
result.ProjectID = loadResp.CloudAICompanionProject
|
} else {
|
||||||
|
result.ProjectID = projectID
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
|
|||||||
tokenInfo.Email = existingEmail
|
tokenInfo.Email = existingEmail
|
||||||
}
|
}
|
||||||
|
|
||||||
// 每次刷新都调用 LoadCodeAssist 获取 project_id
|
// 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试
|
||||||
client := antigravity.NewClient(proxyURL)
|
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||||
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
|
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
|
||||||
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
|
|
||||||
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
|
if loadErr != nil {
|
||||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
// LoadCodeAssist 失败,保留原有 project_id
|
||||||
tokenInfo.ProjectID = existingProjectID
|
tokenInfo.ProjectID = existingProjectID
|
||||||
tokenInfo.ProjectIDMissing = true
|
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
|
||||||
|
// 如果之前有 project_id,本次只是临时故障,不应标记为错误
|
||||||
|
if existingProjectID == "" {
|
||||||
|
tokenInfo.ProjectIDMissing = true
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
|
tokenInfo.ProjectID = projectID
|
||||||
}
|
}
|
||||||
|
|
||||||
return tokenInfo, nil
|
return tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loadProjectIDWithRetry 带重试机制获取 project_id
|
||||||
|
// 返回 project_id 和错误,失败时会重试指定次数
|
||||||
|
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||||
|
if attempt > 0 {
|
||||||
|
// 指数退避:1s, 2s, 4s
|
||||||
|
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
|
||||||
|
if backoff > 8*time.Second {
|
||||||
|
backoff = 8 * time.Second
|
||||||
|
}
|
||||||
|
time.Sleep(backoff)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := antigravity.NewClient(proxyURL)
|
||||||
|
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||||
|
|
||||||
|
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||||
|
return loadResp.CloudAICompanionProject, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录错误
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
} else if loadResp == nil {
|
||||||
|
lastErr = fmt.Errorf("LoadCodeAssist 返回空响应")
|
||||||
|
} else {
|
||||||
|
lastErr = fmt.Errorf("LoadCodeAssist 返回空 project_id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
// BuildAccountCredentials 构建账户凭证
|
// BuildAccountCredentials 构建账户凭证
|
||||||
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
||||||
creds := map[string]any{
|
creds := map[string]any{
|
||||||
|
|||||||
@@ -94,14 +94,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
|||||||
|
|
||||||
var handleErrorCalled bool
|
var handleErrorCalled bool
|
||||||
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
|
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
|
||||||
prefix: "[test]",
|
prefix: "[test]",
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
account: account,
|
account: account,
|
||||||
proxyURL: "",
|
proxyURL: "",
|
||||||
accessToken: "token",
|
accessToken: "token",
|
||||||
action: "generateContent",
|
action: "generateContent",
|
||||||
body: []byte(`{"input":"test"}`),
|
body: []byte(`{"input":"test"}`),
|
||||||
quotaScope: AntigravityQuotaScopeClaude,
|
quotaScope: AntigravityQuotaScopeClaude,
|
||||||
httpUpstream: upstream,
|
httpUpstream: upstream,
|
||||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||||
handleErrorCalled = true
|
handleErrorCalled = true
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
|
"log/slog"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
|||||||
return "", errors.New("access_token not found in credentials")
|
return "", errors.New("access_token not found in credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 存入缓存
|
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
ttl := 30 * time.Minute
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||||
if expiresAt != nil {
|
if isStale && latestAccount != nil {
|
||||||
until := time.Until(*expiresAt)
|
// 版本过时,使用 DB 中的最新 token
|
||||||
switch {
|
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
|
||||||
case until > antigravityTokenCacheSkew:
|
accessToken = latestAccount.GetCredential("access_token")
|
||||||
ttl = until - antigravityTokenCacheSkew
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
case until > 0:
|
return "", errors.New("access_token not found after version check")
|
||||||
ttl = until
|
|
||||||
default:
|
|
||||||
ttl = time.Minute
|
|
||||||
}
|
}
|
||||||
|
// 不写入缓存,让下次请求重新处理
|
||||||
|
} else {
|
||||||
|
ttl := 30 * time.Minute
|
||||||
|
if expiresAt != nil {
|
||||||
|
until := time.Until(*expiresAt)
|
||||||
|
switch {
|
||||||
|
case until > antigravityTokenCacheSkew:
|
||||||
|
ttl = until - antigravityTokenCacheSkew
|
||||||
|
case until > 0:
|
||||||
|
ttl = until
|
||||||
|
default:
|
||||||
|
ttl = time.Minute
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||||
}
|
}
|
||||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
|
|||||||
}
|
}
|
||||||
|
|
||||||
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
|
// 合并旧的 credentials,保留新 credentials 中不存在的字段
|
||||||
for k, v := range account.Credentials {
|
for k, v := range account.Credentials {
|
||||||
if _, exists := newCredentials[k]; !exists {
|
if _, exists := newCredentials[k]; !exists {
|
||||||
newCredentials[k] = v
|
newCredentials[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
|
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||||
|
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||||
|
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
|
||||||
|
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
|
||||||
|
newCredentials["project_id"] = oldProjectID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果 project_id 获取失败,只记录警告,不返回错误
|
||||||
|
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
|
||||||
|
// Token 刷新本身是成功的(access_token 和 refresh_token 已更新)
|
||||||
if tokenInfo.ProjectIDMissing {
|
if tokenInfo.ProjectIDMissing {
|
||||||
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity")
|
if tokenInfo.ProjectID != "" {
|
||||||
|
// 有旧的 project_id,本次获取失败,保留旧值
|
||||||
|
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id", account.ID)
|
||||||
|
} else {
|
||||||
|
// 从未获取过 project_id,本次也失败,但不返回错误以允许下次重试
|
||||||
|
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败,project_id 缺失,但 token 已更新,将在下次刷新时重试", account.ID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return newCredentials, nil
|
return newCredentials, nil
|
||||||
|
|||||||
@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
|||||||
return "", nil, ErrServiceUnavailable
|
return "", nil, ErrServiceUnavailable
|
||||||
}
|
}
|
||||||
|
|
||||||
// 应用优惠码(如果提供)
|
// 应用优惠码(如果提供且功能已启用)
|
||||||
if promoCode != "" && s.promoService != nil {
|
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
|
||||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||||
// 优惠码应用失败不影响注册,只记录日志
|
// 优惠码应用失败不影响注册,只记录日志
|
||||||
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
|
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
|
||||||
@@ -580,3 +580,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
|
|||||||
// 生成新token
|
// 生成新token
|
||||||
return s.GenerateToken(user)
|
return s.GenerateToken(user)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsPasswordResetEnabled 检查是否启用密码重置功能
|
||||||
|
// 要求:必须同时开启邮件验证且 SMTP 配置正确
|
||||||
|
func (s *AuthService) IsPasswordResetEnabled(ctx context.Context) bool {
|
||||||
|
if s.settingService == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Must have email verification enabled and SMTP configured
|
||||||
|
if !s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.settingService.IsPasswordResetEnabled(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// preparePasswordReset validates the password reset request and returns necessary data
|
||||||
|
// Returns (siteName, resetURL, shouldProceed)
|
||||||
|
// shouldProceed is false when we should silently return success (to prevent enumeration)
|
||||||
|
func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendBaseURL string) (string, string, bool) {
|
||||||
|
// Check if user exists (but don't reveal this to the caller)
|
||||||
|
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUserNotFound) {
|
||||||
|
// Security: Log but don't reveal that user doesn't exist
|
||||||
|
log.Printf("[Auth] Password reset requested for non-existent email: %s", email)
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
log.Printf("[Auth] Database error checking email for password reset: %v", err)
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if user is active
|
||||||
|
if !user.IsActive() {
|
||||||
|
log.Printf("[Auth] Password reset requested for inactive user: %s", email)
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get site name
|
||||||
|
siteName := "Sub2API"
|
||||||
|
if s.settingService != nil {
|
||||||
|
siteName = s.settingService.GetSiteName(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build reset URL base
|
||||||
|
resetURL := fmt.Sprintf("%s/reset-password", strings.TrimSuffix(frontendBaseURL, "/"))
|
||||||
|
|
||||||
|
return siteName, resetURL, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestPasswordReset 请求密码重置(同步发送)
|
||||||
|
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
||||||
|
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error {
|
||||||
|
if !s.IsPasswordResetEnabled(ctx) {
|
||||||
|
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||||
|
}
|
||||||
|
if s.emailService == nil {
|
||||||
|
return ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
|
||||||
|
if !shouldProceed {
|
||||||
|
return nil // Silent success to prevent enumeration
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
|
||||||
|
log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err)
|
||||||
|
return nil // Silent success to prevent enumeration
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[Auth] Password reset email sent to: %s", email)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
|
||||||
|
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
||||||
|
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error {
|
||||||
|
if !s.IsPasswordResetEnabled(ctx) {
|
||||||
|
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||||
|
}
|
||||||
|
if s.emailQueueService == nil {
|
||||||
|
return ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
|
||||||
|
if !shouldProceed {
|
||||||
|
return nil // Silent success to prevent enumeration
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil {
|
||||||
|
log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err)
|
||||||
|
return nil // Silent success to prevent enumeration
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[Auth] Password reset email enqueued for: %s", email)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetPassword 重置密码
|
||||||
|
// Security: Increments TokenVersion to invalidate all existing JWT tokens
|
||||||
|
func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPassword string) error {
|
||||||
|
// Check if password reset is enabled
|
||||||
|
if !s.IsPasswordResetEnabled(ctx) {
|
||||||
|
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.emailService == nil {
|
||||||
|
return ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify and consume the reset token (one-time use)
|
||||||
|
if err := s.emailService.ConsumePasswordResetToken(ctx, email, token); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get user
|
||||||
|
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUserNotFound) {
|
||||||
|
return ErrInvalidResetToken // Token was valid but user was deleted
|
||||||
|
}
|
||||||
|
log.Printf("[Auth] Database error getting user for password reset: %v", err)
|
||||||
|
return ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if user is active
|
||||||
|
if !user.IsActive() {
|
||||||
|
return ErrUserNotActive
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash new password
|
||||||
|
hashedPassword, err := s.HashPassword(newPassword)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("hash password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update password and increment TokenVersion
|
||||||
|
user.PasswordHash = hashedPassword
|
||||||
|
user.TokenVersion++ // Invalidate all existing tokens
|
||||||
|
|
||||||
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
|
log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err)
|
||||||
|
return ErrServiceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[Auth] Password reset successful for user: %s", email)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *emailCacheStub) SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *emailCacheStub) DeletePasswordResetToken(ctx context.Context, email string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *emailCacheStub) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
|
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
JWT: config.JWTConfig{
|
JWT: config.JWTConfig{
|
||||||
|
|||||||
@@ -181,26 +181,37 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
return "", errors.New("access_token not found in credentials")
|
return "", errors.New("access_token not found in credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 存入缓存
|
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
ttl := 30 * time.Minute
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||||
if refreshFailed {
|
if isStale && latestAccount != nil {
|
||||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
// 版本过时,使用 DB 中的最新 token
|
||||||
ttl = time.Minute
|
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
|
||||||
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
accessToken = latestAccount.GetCredential("access_token")
|
||||||
} else if expiresAt != nil {
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
until := time.Until(*expiresAt)
|
return "", errors.New("access_token not found after version check")
|
||||||
switch {
|
}
|
||||||
case until > claudeTokenCacheSkew:
|
// 不写入缓存,让下次请求重新处理
|
||||||
ttl = until - claudeTokenCacheSkew
|
} else {
|
||||||
case until > 0:
|
ttl := 30 * time.Minute
|
||||||
ttl = until
|
if refreshFailed {
|
||||||
default:
|
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||||
ttl = time.Minute
|
ttl = time.Minute
|
||||||
|
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||||
|
} else if expiresAt != nil {
|
||||||
|
until := time.Until(*expiresAt)
|
||||||
|
switch {
|
||||||
|
case until > claudeTokenCacheSkew:
|
||||||
|
ttl = until - claudeTokenCacheSkew
|
||||||
|
case until > 0:
|
||||||
|
ttl = until
|
||||||
|
default:
|
||||||
|
ttl = time.Minute
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||||
|
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
|
||||||
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -69,8 +69,10 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
|||||||
// Setting keys
|
// Setting keys
|
||||||
const (
|
const (
|
||||||
// 注册设置
|
// 注册设置
|
||||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||||
|
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||||
|
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||||
|
|
||||||
// 邮件服务设置
|
// 邮件服务设置
|
||||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||||
@@ -86,6 +88,9 @@ const (
|
|||||||
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
|
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
|
||||||
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
|
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
|
||||||
|
|
||||||
|
// TOTP 双因素认证设置
|
||||||
|
SettingKeyTotpEnabled = "totp_enabled" // 是否启用 TOTP 2FA 功能
|
||||||
|
|
||||||
// LinuxDo Connect OAuth 登录设置
|
// LinuxDo Connect OAuth 登录设置
|
||||||
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
||||||
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
||||||
@@ -93,13 +98,16 @@ const (
|
|||||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||||
|
|
||||||
// OEM设置
|
// OEM设置
|
||||||
SettingKeySiteName = "site_name" // 网站名称
|
SettingKeySiteName = "site_name" // 网站名称
|
||||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||||
SettingKeyDocURL = "doc_url" // 文档链接
|
SettingKeyDocURL = "doc_url" // 文档链接
|
||||||
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
||||||
|
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
|
||||||
|
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示“购买订阅”页面入口
|
||||||
|
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src)
|
||||||
|
|
||||||
// 默认配置
|
// 默认配置
|
||||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||||
|
|||||||
@@ -8,11 +8,18 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Task type constants
|
||||||
|
const (
|
||||||
|
TaskTypeVerifyCode = "verify_code"
|
||||||
|
TaskTypePasswordReset = "password_reset"
|
||||||
|
)
|
||||||
|
|
||||||
// EmailTask 邮件发送任务
|
// EmailTask 邮件发送任务
|
||||||
type EmailTask struct {
|
type EmailTask struct {
|
||||||
Email string
|
Email string
|
||||||
SiteName string
|
SiteName string
|
||||||
TaskType string // "verify_code"
|
TaskType string // "verify_code" or "password_reset"
|
||||||
|
ResetURL string // Only used for password_reset task type
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmailQueueService 异步邮件队列服务
|
// EmailQueueService 异步邮件队列服务
|
||||||
@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
switch task.TaskType {
|
switch task.TaskType {
|
||||||
case "verify_code":
|
case TaskTypeVerifyCode:
|
||||||
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
|
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
|
||||||
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
|
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
|
||||||
} else {
|
} else {
|
||||||
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
|
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
|
||||||
}
|
}
|
||||||
|
case TaskTypePasswordReset:
|
||||||
|
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
|
||||||
|
log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
|
||||||
|
} else {
|
||||||
|
log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
|
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
|
||||||
}
|
}
|
||||||
@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
|||||||
task := EmailTask{
|
task := EmailTask{
|
||||||
Email: email,
|
Email: email,
|
||||||
SiteName: siteName,
|
SiteName: siteName,
|
||||||
TaskType: "verify_code",
|
TaskType: TaskTypeVerifyCode,
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EnqueuePasswordReset 将密码重置邮件任务加入队列
|
||||||
|
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error {
|
||||||
|
task := EmailTask{
|
||||||
|
Email: email,
|
||||||
|
SiteName: siteName,
|
||||||
|
TaskType: TaskTypePasswordReset,
|
||||||
|
ResetURL: resetURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case s.taskChan <- task:
|
||||||
|
log.Printf("[EmailQueue] Enqueued password reset task for %s", email)
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("email queue is full")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Stop 停止队列服务
|
// Stop 停止队列服务
|
||||||
func (s *EmailQueueService) Stop() {
|
func (s *EmailQueueService) Stop() {
|
||||||
close(s.stopChan)
|
close(s.stopChan)
|
||||||
|
|||||||
@@ -3,11 +3,14 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/smtp"
|
"net/smtp"
|
||||||
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,6 +22,9 @@ var (
|
|||||||
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
|
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
|
||||||
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
|
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
|
||||||
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
|
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
|
||||||
|
|
||||||
|
// Password reset errors
|
||||||
|
ErrInvalidResetToken = infraerrors.BadRequest("INVALID_RESET_TOKEN", "invalid or expired password reset token")
|
||||||
)
|
)
|
||||||
|
|
||||||
// EmailCache defines cache operations for email service
|
// EmailCache defines cache operations for email service
|
||||||
@@ -26,6 +32,16 @@ type EmailCache interface {
|
|||||||
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
|
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
|
||||||
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
|
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
|
||||||
DeleteVerificationCode(ctx context.Context, email string) error
|
DeleteVerificationCode(ctx context.Context, email string) error
|
||||||
|
|
||||||
|
// Password reset token methods
|
||||||
|
GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error)
|
||||||
|
SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error
|
||||||
|
DeletePasswordResetToken(ctx context.Context, email string) error
|
||||||
|
|
||||||
|
// Password reset email cooldown methods
|
||||||
|
// Returns true if in cooldown period (email was sent recently)
|
||||||
|
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
|
||||||
|
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerificationCodeData represents verification code data
|
// VerificationCodeData represents verification code data
|
||||||
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
|
|||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PasswordResetTokenData represents password reset token data
|
||||||
|
type PasswordResetTokenData struct {
|
||||||
|
Token string
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
verifyCodeTTL = 15 * time.Minute
|
verifyCodeTTL = 15 * time.Minute
|
||||||
verifyCodeCooldown = 1 * time.Minute
|
verifyCodeCooldown = 1 * time.Minute
|
||||||
maxVerifyCodeAttempts = 5
|
maxVerifyCodeAttempts = 5
|
||||||
|
|
||||||
|
// Password reset token settings
|
||||||
|
passwordResetTokenTTL = 30 * time.Minute
|
||||||
|
|
||||||
|
// Password reset email cooldown (prevent email bombing)
|
||||||
|
passwordResetEmailCooldown = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// SMTPConfig SMTP配置
|
// SMTPConfig SMTP配置
|
||||||
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
|||||||
return ErrVerifyCodeMaxAttempts
|
return ErrVerifyCodeMaxAttempts
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证码不匹配
|
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
|
||||||
if data.Code != code {
|
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
|
||||||
data.Attempts++
|
data.Attempts++
|
||||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||||
log.Printf("[Email] Failed to update verification attempt count: %v", err)
|
log.Printf("[Email] Failed to update verification attempt count: %v", err)
|
||||||
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
|
|||||||
|
|
||||||
return client.Quit()
|
return client.Quit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
|
||||||
|
func (s *EmailService) GeneratePasswordResetToken() (string, error) {
|
||||||
|
bytes := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendPasswordResetEmail sends a password reset email with a reset link
|
||||||
|
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error {
|
||||||
|
var token string
|
||||||
|
var needSaveToken bool
|
||||||
|
|
||||||
|
// Check if token already exists
|
||||||
|
existing, err := s.cache.GetPasswordResetToken(ctx, email)
|
||||||
|
if err == nil && existing != nil {
|
||||||
|
// Token exists, reuse it (allows resending email without generating new token)
|
||||||
|
token = existing.Token
|
||||||
|
needSaveToken = false
|
||||||
|
} else {
|
||||||
|
// Generate new token
|
||||||
|
token, err = s.GeneratePasswordResetToken()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate token: %w", err)
|
||||||
|
}
|
||||||
|
needSaveToken = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save token to Redis (only if new token generated)
|
||||||
|
if needSaveToken {
|
||||||
|
data := &PasswordResetTokenData{
|
||||||
|
Token: token,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if err := s.cache.SetPasswordResetToken(ctx, email, data, passwordResetTokenTTL); err != nil {
|
||||||
|
return fmt.Errorf("save reset token: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build full reset URL with URL-encoded token and email
|
||||||
|
fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token))
|
||||||
|
|
||||||
|
// Build email content
|
||||||
|
subject := fmt.Sprintf("[%s] 密码重置请求", siteName)
|
||||||
|
body := s.buildPasswordResetEmailBody(fullResetURL, siteName)
|
||||||
|
|
||||||
|
// Send email
|
||||||
|
if err := s.SendEmail(ctx, email, subject, body); err != nil {
|
||||||
|
return fmt.Errorf("send email: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
|
||||||
|
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
|
||||||
|
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
|
||||||
|
// Check email cooldown to prevent email bombing
|
||||||
|
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
|
||||||
|
log.Printf("[Email] Password reset email skipped (cooldown): %s", email)
|
||||||
|
return nil // Silent success to prevent revealing cooldown to attackers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send email using core method
|
||||||
|
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set cooldown marker (Redis TTL handles expiration)
|
||||||
|
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
|
||||||
|
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyPasswordResetToken verifies the password reset token without consuming it
|
||||||
|
func (s *EmailService) VerifyPasswordResetToken(ctx context.Context, email, token string) error {
|
||||||
|
data, err := s.cache.GetPasswordResetToken(ctx, email)
|
||||||
|
if err != nil || data == nil {
|
||||||
|
return ErrInvalidResetToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use constant-time comparison to prevent timing attacks
|
||||||
|
if subtle.ConstantTimeCompare([]byte(data.Token), []byte(token)) != 1 {
|
||||||
|
return ErrInvalidResetToken
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
|
||||||
|
func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, token string) error {
|
||||||
|
// Verify first
|
||||||
|
if err := s.VerifyPasswordResetToken(ctx, email, token); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete after verification (one-time use)
|
||||||
|
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
|
||||||
|
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildPasswordResetEmailBody builds the HTML content for password reset email
|
||||||
|
func (s *EmailService) buildPasswordResetEmailBody(resetURL, siteName string) string {
|
||||||
|
return fmt.Sprintf(`
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<style>
|
||||||
|
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
|
||||||
|
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
|
||||||
|
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
|
||||||
|
.header h1 { margin: 0; font-size: 24px; }
|
||||||
|
.content { padding: 40px 30px; text-align: center; }
|
||||||
|
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
|
||||||
|
.button:hover { opacity: 0.9; }
|
||||||
|
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
|
||||||
|
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
|
||||||
|
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
|
||||||
|
.warning { color: #e74c3c; font-weight: 500; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<div class="header">
|
||||||
|
<h1>%s</h1>
|
||||||
|
</div>
|
||||||
|
<div class="content">
|
||||||
|
<p style="font-size: 18px; color: #333;">密码重置请求</p>
|
||||||
|
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
|
||||||
|
<a href="%s" class="button">重置密码</a>
|
||||||
|
<div class="info">
|
||||||
|
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
|
||||||
|
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
|
||||||
|
</div>
|
||||||
|
<div class="link-fallback">
|
||||||
|
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
|
||||||
|
<p>%s</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="footer">
|
||||||
|
<p>这是一封自动发送的邮件,请勿回复。</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`, siteName, resetURL, resetURL)
|
||||||
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -99,11 +99,24 @@ var allowedHeaders = map[string]bool{
|
|||||||
"content-type": true,
|
"content-type": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// GatewayCache defines cache operations for gateway service
|
// GatewayCache 定义网关服务的缓存操作接口。
|
||||||
|
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
|
||||||
|
//
|
||||||
|
// GatewayCache defines cache operations for gateway service.
|
||||||
|
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
|
||||||
type GatewayCache interface {
|
type GatewayCache interface {
|
||||||
|
// GetSessionAccountID 获取粘性会话绑定的账号 ID
|
||||||
|
// Get the account ID bound to a sticky session
|
||||||
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
|
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
|
||||||
|
// SetSessionAccountID 设置粘性会话与账号的绑定关系
|
||||||
|
// Set the binding between sticky session and account
|
||||||
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
|
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
|
||||||
|
// RefreshSessionTTL 刷新粘性会话的过期时间
|
||||||
|
// Refresh the expiration time of a sticky session
|
||||||
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
|
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
|
||||||
|
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
|
||||||
|
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
|
||||||
|
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
||||||
@@ -114,6 +127,28 @@ func derefGroupID(groupID *int64) int64 {
|
|||||||
return *groupID
|
return *groupID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
|
||||||
|
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
|
||||||
|
// 这确保后续请求不会继续使用不可用的账号。
|
||||||
|
//
|
||||||
|
// shouldClearStickySession checks if an account is in an unschedulable state
|
||||||
|
// and the sticky session binding should be cleared.
|
||||||
|
// Returns true when account status is error/disabled, schedulable is false,
|
||||||
|
// or within temporary unschedulable period.
|
||||||
|
// This ensures subsequent requests won't continue using unavailable accounts.
|
||||||
|
func shouldClearStickySession(account *Account) bool {
|
||||||
|
if account == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type AccountWaitPlan struct {
|
type AccountWaitPlan struct {
|
||||||
AccountID int64
|
AccountID int64
|
||||||
MaxConcurrency int
|
MaxConcurrency int
|
||||||
@@ -270,6 +305,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
|
|||||||
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
|
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
|
||||||
|
// Returns 0 if no binding exists or on error.
|
||||||
|
func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
|
||||||
|
if sessionHash == "" || s.cache == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return accountID, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||||
if parsed == nil {
|
if parsed == nil {
|
||||||
return ""
|
return ""
|
||||||
@@ -658,6 +706,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
}
|
}
|
||||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -764,41 +814,52 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||||
account, ok := accountByID[accountID]
|
account, ok := accountByID[accountID]
|
||||||
if ok && s.isAccountInGroup(account, groupID) &&
|
if ok {
|
||||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
// 检查账户是否需要清理粘性会话绑定
|
||||||
account.IsSchedulableForModel(requestedModel) &&
|
// Check if the account needs sticky session cleanup
|
||||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
|
clearSticky := shouldClearStickySession(account)
|
||||||
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
|
if clearSticky {
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
if err == nil && result.Acquired {
|
|
||||||
// 会话数量限制检查
|
|
||||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
|
||||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
|
||||||
} else {
|
|
||||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
|
||||||
return &AccountSelectionResult{
|
|
||||||
Account: account,
|
|
||||||
Acquired: true,
|
|
||||||
ReleaseFunc: result.ReleaseFunc,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if !clearSticky && s.isAccountInGroup(account, groupID) &&
|
||||||
|
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||||
|
account.IsSchedulableForModel(requestedModel) &&
|
||||||
|
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
|
||||||
|
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
|
||||||
|
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
|
if err == nil && result.Acquired {
|
||||||
|
// 会话数量限制检查
|
||||||
|
// Session count limit check
|
||||||
|
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||||
|
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||||
|
} else {
|
||||||
|
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: account,
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
// Session count limit check (wait plan also requires session quota)
|
||||||
// 会话限制已满,继续到 Layer 2
|
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||||
} else {
|
// 会话限制已满,继续到 Layer 2
|
||||||
return &AccountSelectionResult{
|
// Session limit full, continue to Layer 2
|
||||||
Account: account,
|
} else {
|
||||||
WaitPlan: &AccountWaitPlan{
|
return &AccountSelectionResult{
|
||||||
AccountID: accountID,
|
Account: account,
|
||||||
MaxConcurrency: account.Concurrency,
|
WaitPlan: &AccountWaitPlan{
|
||||||
Timeout: cfg.StickySessionWaitTimeout,
|
AccountID: accountID,
|
||||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
MaxConcurrency: account.Concurrency,
|
||||||
},
|
Timeout: cfg.StickySessionWaitTimeout,
|
||||||
}, nil
|
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1418,14 +1479,20 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil {
|
||||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
clearSticky := shouldClearStickySession(account)
|
||||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
if clearSticky {
|
||||||
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
}
|
}
|
||||||
if s.debugModelRoutingEnabled() {
|
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||||
|
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||||
|
}
|
||||||
|
if s.debugModelRoutingEnabled() {
|
||||||
|
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||||
|
}
|
||||||
|
return account, nil
|
||||||
}
|
}
|
||||||
return account, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1515,11 +1582,17 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil {
|
||||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
clearSticky := shouldClearStickySession(account)
|
||||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
if clearSticky {
|
||||||
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
|
}
|
||||||
|
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
|
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||||
|
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||||
|
}
|
||||||
|
return account, nil
|
||||||
}
|
}
|
||||||
return account, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1619,15 +1692,21 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil {
|
||||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
clearSticky := shouldClearStickySession(account)
|
||||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
if clearSticky {
|
||||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
|
}
|
||||||
|
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
|
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||||
|
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||||
|
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||||
|
}
|
||||||
|
if s.debugModelRoutingEnabled() {
|
||||||
|
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||||
|
}
|
||||||
|
return account, nil
|
||||||
}
|
}
|
||||||
if s.debugModelRoutingEnabled() {
|
|
||||||
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
|
||||||
}
|
|
||||||
return account, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1718,12 +1797,18 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil {
|
||||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
clearSticky := shouldClearStickySession(account)
|
||||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
if clearSticky {
|
||||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
|
}
|
||||||
|
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
|
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||||
|
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||||
|
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||||
|
}
|
||||||
|
return account, nil
|
||||||
}
|
}
|
||||||
return account, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3287,17 +3372,19 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
|||||||
} `json:"usage"`
|
} `json:"usage"`
|
||||||
}
|
}
|
||||||
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
|
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
|
||||||
// output_tokens 总是从 message_delta 获取
|
// message_delta 仅覆盖存在且非0的字段
|
||||||
usage.OutputTokens = msgDelta.Usage.OutputTokens
|
// 避免覆盖 message_start 中已有的值(如 input_tokens)
|
||||||
|
// Claude API 的 message_delta 通常只包含 output_tokens
|
||||||
// 如果 message_start 中没有值,则从 message_delta 获取(兼容GLM等API)
|
if msgDelta.Usage.InputTokens > 0 {
|
||||||
if usage.InputTokens == 0 {
|
|
||||||
usage.InputTokens = msgDelta.Usage.InputTokens
|
usage.InputTokens = msgDelta.Usage.InputTokens
|
||||||
}
|
}
|
||||||
if usage.CacheCreationInputTokens == 0 {
|
if msgDelta.Usage.OutputTokens > 0 {
|
||||||
|
usage.OutputTokens = msgDelta.Usage.OutputTokens
|
||||||
|
}
|
||||||
|
if msgDelta.Usage.CacheCreationInputTokens > 0 {
|
||||||
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
|
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
|
||||||
}
|
}
|
||||||
if usage.CacheReadInputTokens == 0 {
|
if msgDelta.Usage.CacheReadInputTokens > 0 {
|
||||||
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,70 +82,23 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
// 1. 确定目标平台和调度模式
|
||||||
var platform string
|
// Determine target platform and scheduling mode
|
||||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
platform, useMixedScheduling, hasForcePlatform, err := s.resolvePlatformAndSchedulingMode(ctx, groupID)
|
||||||
if hasForcePlatform && forcePlatform != "" {
|
if err != nil {
|
||||||
platform = forcePlatform
|
return nil, err
|
||||||
} else if groupID != nil {
|
|
||||||
// 根据分组 platform 决定查询哪种账号
|
|
||||||
var group *Group
|
|
||||||
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
|
|
||||||
group = ctxGroup
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get group failed: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
platform = group.Platform
|
|
||||||
} else {
|
|
||||||
// 无分组时只使用原生 gemini 平台
|
|
||||||
platform = PlatformGemini
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
|
||||||
// 注意:强制平台模式不走混合调度
|
|
||||||
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
|
||||||
|
|
||||||
cacheKey := "gemini:" + sessionHash
|
cacheKey := "gemini:" + sessionHash
|
||||||
|
|
||||||
if sessionHash != "" {
|
// 2. 尝试粘性会话命中
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
// Try sticky session hit
|
||||||
if err == nil && accountID > 0 {
|
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs, platform, useMixedScheduling); account != nil {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
return account, nil
|
||||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
|
||||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
|
||||||
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
|
||||||
valid := false
|
|
||||||
if account.Platform == platform {
|
|
||||||
valid = true
|
|
||||||
} else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
|
|
||||||
valid = true
|
|
||||||
}
|
|
||||||
if valid {
|
|
||||||
usable := true
|
|
||||||
if s.rateLimitService != nil && requestedModel != "" {
|
|
||||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
usable = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if usable {
|
|
||||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
|
|
||||||
return account, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
// 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||||||
|
// Query schedulable accounts (force platform mode: try group first, fallback to all)
|
||||||
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
|
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
@@ -158,56 +111,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var selected *Account
|
// 4. 按优先级 + LRU 选择最佳账号
|
||||||
for i := range accounts {
|
// Select best account by priority + LRU
|
||||||
acc := &accounts[i]
|
selected := s.selectBestGeminiAccount(ctx, accounts, requestedModel, excludedIDs, platform, useMixedScheduling)
|
||||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// 混合调度模式下:原生平台直接通过,antigravity 需要启用 mixed_scheduling
|
|
||||||
// 非混合调度模式(antigravity 分组):不需要过滤
|
|
||||||
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !acc.IsSchedulableForModel(requestedModel) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if s.rateLimitService != nil && requestedModel != "" {
|
|
||||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if selected == nil {
|
|
||||||
selected = acc
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if acc.Priority < selected.Priority {
|
|
||||||
selected = acc
|
|
||||||
} else if acc.Priority == selected.Priority {
|
|
||||||
switch {
|
|
||||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
|
||||||
selected = acc
|
|
||||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
|
||||||
// keep selected (never used is preferred)
|
|
||||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
|
||||||
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
|
|
||||||
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
|
|
||||||
selected = acc
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
|
||||||
selected = acc
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if selected == nil {
|
if selected == nil {
|
||||||
if requestedModel != "" {
|
if requestedModel != "" {
|
||||||
@@ -216,6 +122,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
return nil, errors.New("no available Gemini accounts")
|
return nil, errors.New("no available Gemini accounts")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 5. 设置粘性会话绑定
|
||||||
|
// Set sticky session binding
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
|
||||||
}
|
}
|
||||||
@@ -223,6 +131,229 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
return selected, nil
|
return selected, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
|
||||||
|
// 返回:平台名称、是否使用混合调度、是否强制平台、错误。
|
||||||
|
//
|
||||||
|
// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode.
|
||||||
|
// Returns: platform name, whether to use mixed scheduling, whether force platform, error.
|
||||||
|
func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, err error) {
|
||||||
|
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||||
|
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||||
|
if hasForcePlatform && forcePlatform != "" {
|
||||||
|
return forcePlatform, false, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if groupID != nil {
|
||||||
|
// 根据分组 platform 决定查询哪种账号
|
||||||
|
var group *Group
|
||||||
|
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
|
||||||
|
group = ctxGroup
|
||||||
|
} else {
|
||||||
|
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, false, fmt.Errorf("get group failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||||
|
return group.Platform, group.Platform == PlatformGemini, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 无分组时只使用原生 gemini 平台
|
||||||
|
return PlatformGemini, true, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryStickySessionHit 尝试从粘性会话获取账号。
|
||||||
|
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
|
||||||
|
//
|
||||||
|
// tryStickySessionHit attempts to get account from sticky session.
|
||||||
|
// Returns account if hit and usable; clears session and returns nil if account unavailable.
|
||||||
|
func (s *GeminiMessagesCompatService) tryStickySessionHit(
|
||||||
|
ctx context.Context,
|
||||||
|
groupID *int64,
|
||||||
|
sessionHash, cacheKey, requestedModel string,
|
||||||
|
excludedIDs map[int64]struct{},
|
||||||
|
platform string,
|
||||||
|
useMixedScheduling bool,
|
||||||
|
) *Account {
|
||||||
|
if sessionHash == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||||
|
if err != nil || accountID <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, excluded := excludedIDs[accountID]; excluded {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查账号是否需要清理粘性会话
|
||||||
|
// Check if sticky session should be cleared
|
||||||
|
if shouldClearStickySession(account) {
|
||||||
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证账号是否可用于当前请求
|
||||||
|
// Verify account is usable for current request
|
||||||
|
if !s.isAccountUsableForRequest(ctx, account, requestedModel, platform, useMixedScheduling) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 刷新会话 TTL 并返回账号
|
||||||
|
// Refresh session TTL and return account
|
||||||
|
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
|
||||||
|
return account
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAccountUsableForRequest 检查账号是否可用于当前请求。
|
||||||
|
// 验证:模型调度、模型支持、平台匹配、速率限制预检。
|
||||||
|
//
|
||||||
|
// isAccountUsableForRequest checks if account is usable for current request.
|
||||||
|
// Validates: model scheduling, model support, platform matching, rate limit precheck.
|
||||||
|
func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
|
||||||
|
ctx context.Context,
|
||||||
|
account *Account,
|
||||||
|
requestedModel, platform string,
|
||||||
|
useMixedScheduling bool,
|
||||||
|
) bool {
|
||||||
|
// 检查模型调度能力
|
||||||
|
// Check model scheduling capability
|
||||||
|
if !account.IsSchedulableForModel(requestedModel) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查模型支持
|
||||||
|
// Check model support
|
||||||
|
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查平台匹配
|
||||||
|
// Check platform matching
|
||||||
|
if !s.isAccountValidForPlatform(account, platform, useMixedScheduling) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 速率限制预检
|
||||||
|
// Rate limit precheck
|
||||||
|
if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAccountValidForPlatform 检查账号是否匹配目标平台。
|
||||||
|
// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。
|
||||||
|
//
|
||||||
|
// isAccountValidForPlatform checks if account matches target platform.
|
||||||
|
// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling.
|
||||||
|
func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account, platform string, useMixedScheduling bool) bool {
|
||||||
|
if account.Platform == platform {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// passesRateLimitPreCheck 执行速率限制预检。
|
||||||
|
// 返回 true 表示通过预检或无需预检。
|
||||||
|
//
|
||||||
|
// passesRateLimitPreCheck performs rate limit precheck.
|
||||||
|
// Returns true if passed or precheck not required.
|
||||||
|
func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
|
||||||
|
if s.rateLimitService == nil || requestedModel == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||||||
|
}
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。
|
||||||
|
// 返回 nil 表示无可用账号。
|
||||||
|
//
|
||||||
|
// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred).
|
||||||
|
// Returns nil if no available account.
|
||||||
|
func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
|
||||||
|
ctx context.Context,
|
||||||
|
accounts []Account,
|
||||||
|
requestedModel string,
|
||||||
|
excludedIDs map[int64]struct{},
|
||||||
|
platform string,
|
||||||
|
useMixedScheduling bool,
|
||||||
|
) *Account {
|
||||||
|
var selected *Account
|
||||||
|
|
||||||
|
for i := range accounts {
|
||||||
|
acc := &accounts[i]
|
||||||
|
|
||||||
|
// 跳过被排除的账号
|
||||||
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查账号是否可用于当前请求
|
||||||
|
if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 选择最佳账号
|
||||||
|
if selected == nil {
|
||||||
|
selected = acc
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.isBetterGeminiAccount(acc, selected) {
|
||||||
|
selected = acc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
|
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
|
||||||
|
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。
|
||||||
|
//
|
||||||
|
// isBetterGeminiAccount checks if candidate is better than current.
|
||||||
|
// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used.
|
||||||
|
func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *Account) bool {
|
||||||
|
// 优先级更高(数值更小)
|
||||||
|
if candidate.Priority < current.Priority {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if candidate.Priority > current.Priority {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 同优先级,比较最后使用时间
|
||||||
|
switch {
|
||||||
|
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
|
||||||
|
// candidate 从未使用,优先
|
||||||
|
return true
|
||||||
|
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
|
||||||
|
// current 从未使用,保持
|
||||||
|
return false
|
||||||
|
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
|
||||||
|
// 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程)
|
||||||
|
return candidate.Type == AccountTypeOAuth && current.Type != AccountTypeOAuth
|
||||||
|
default:
|
||||||
|
// 都使用过,选择最久未使用的
|
||||||
|
return candidate.LastUsedAt.Before(*current.LastUsedAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isModelSupportedByAccount 根据账户平台检查模型支持
|
// isModelSupportedByAccount 根据账户平台检查模型支持
|
||||||
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
|
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
|
||||||
if account.Platform == PlatformAntigravity {
|
if account.Platform == PlatformAntigravity {
|
||||||
@@ -800,6 +931,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 图片生成计费
|
||||||
|
imageCount := 0
|
||||||
|
imageSize := s.extractImageSize(body)
|
||||||
|
if isImageGenerationModel(originalModel) {
|
||||||
|
imageCount = 1
|
||||||
|
}
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
@@ -807,6 +945,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
Stream: req.Stream,
|
Stream: req.Stream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
|
ImageCount: imageCount,
|
||||||
|
ImageSize: imageSize,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1240,6 +1380,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
usage = &ClaudeUsage{}
|
usage = &ClaudeUsage{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 图片生成计费
|
||||||
|
imageCount := 0
|
||||||
|
imageSize := s.extractImageSize(body)
|
||||||
|
if isImageGenerationModel(originalModel) {
|
||||||
|
imageCount = 1
|
||||||
|
}
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
@@ -1247,6 +1394,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
Stream: stream,
|
Stream: stream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
|
ImageCount: imageCount,
|
||||||
|
ImageSize: imageSize,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1841,6 +1990,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
|
|||||||
|
|
||||||
var last map[string]any
|
var last map[string]any
|
||||||
var lastWithParts map[string]any
|
var lastWithParts map[string]any
|
||||||
|
var collectedTextParts []string // Collect all text parts for aggregation
|
||||||
usage := &ClaudeUsage{}
|
usage := &ClaudeUsage{}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -1852,7 +2002,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
|
|||||||
switch payload {
|
switch payload {
|
||||||
case "", "[DONE]":
|
case "", "[DONE]":
|
||||||
if payload == "[DONE]" {
|
if payload == "[DONE]" {
|
||||||
return pickGeminiCollectResult(last, lastWithParts), usage, nil
|
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
var parsed map[string]any
|
var parsed map[string]any
|
||||||
@@ -1871,6 +2021,12 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
|
|||||||
}
|
}
|
||||||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||||||
lastWithParts = parsed
|
lastWithParts = parsed
|
||||||
|
// Collect text from each part for aggregation
|
||||||
|
for _, part := range parts {
|
||||||
|
if text, ok := part["text"].(string); ok && text != "" {
|
||||||
|
collectedTextParts = append(collectedTextParts, text)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1885,7 +2041,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return pickGeminiCollectResult(last, lastWithParts), usage, nil
|
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any {
|
func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any {
|
||||||
@@ -1898,6 +2054,83 @@ func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any)
|
|||||||
return map[string]any{}
|
return map[string]any{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mergeCollectedTextParts merges all collected text chunks into the final response.
|
||||||
|
// This fixes the issue where non-streaming responses only returned the last chunk
|
||||||
|
// instead of the complete aggregated text.
|
||||||
|
func mergeCollectedTextParts(response map[string]any, textParts []string) map[string]any {
|
||||||
|
if len(textParts) == 0 {
|
||||||
|
return response
|
||||||
|
}
|
||||||
|
|
||||||
|
// Join all text parts
|
||||||
|
mergedText := strings.Join(textParts, "")
|
||||||
|
|
||||||
|
// Deep copy response
|
||||||
|
result := make(map[string]any)
|
||||||
|
for k, v := range response {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get or create candidates
|
||||||
|
candidates, ok := result["candidates"].([]any)
|
||||||
|
if !ok || len(candidates) == 0 {
|
||||||
|
candidates = []any{map[string]any{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get first candidate
|
||||||
|
candidate, ok := candidates[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
candidate = make(map[string]any)
|
||||||
|
candidates[0] = candidate
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get or create content
|
||||||
|
content, ok := candidate["content"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
content = map[string]any{"role": "model"}
|
||||||
|
candidate["content"] = content
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get existing parts
|
||||||
|
existingParts, ok := content["parts"].([]any)
|
||||||
|
if !ok {
|
||||||
|
existingParts = []any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find and update first text part, or create new one
|
||||||
|
newParts := make([]any, 0, len(existingParts)+1)
|
||||||
|
textUpdated := false
|
||||||
|
|
||||||
|
for _, p := range existingParts {
|
||||||
|
pm, ok := p.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
newParts = append(newParts, p)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, hasText := pm["text"]; hasText && !textUpdated {
|
||||||
|
// Replace with merged text
|
||||||
|
newPart := make(map[string]any)
|
||||||
|
for k, v := range pm {
|
||||||
|
newPart[k] = v
|
||||||
|
}
|
||||||
|
newPart["text"] = mergedText
|
||||||
|
newParts = append(newParts, newPart)
|
||||||
|
textUpdated = true
|
||||||
|
} else {
|
||||||
|
newParts = append(newParts, pm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !textUpdated {
|
||||||
|
newParts = append([]any{map[string]any{"text": mergedText}}, newParts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
content["parts"] = newParts
|
||||||
|
result["candidates"] = candidates
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
type geminiNativeStreamResult struct {
|
type geminiNativeStreamResult struct {
|
||||||
usage *ClaudeUsage
|
usage *ClaudeUsage
|
||||||
firstTokenMs *int
|
firstTokenMs *int
|
||||||
@@ -2816,3 +3049,26 @@ func convertClaudeGenerationConfig(req map[string]any) map[string]any {
|
|||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractImageSize 从 Gemini 请求中提取 image_size 参数
|
||||||
|
func (s *GeminiMessagesCompatService) extractImageSize(body []byte) string {
|
||||||
|
var req struct {
|
||||||
|
GenerationConfig *struct {
|
||||||
|
ImageConfig *struct {
|
||||||
|
ImageSize string `json:"imageSize"`
|
||||||
|
} `json:"imageConfig"`
|
||||||
|
} `json:"generationConfig"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
return "2K"
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
|
||||||
|
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
|
||||||
|
if size == "1K" || size == "2K" || size == "4K" {
|
||||||
|
return size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "2K"
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,8 +15,10 @@ import (
|
|||||||
|
|
||||||
// mockAccountRepoForGemini Gemini 测试用的 mock
|
// mockAccountRepoForGemini Gemini 测试用的 mock
|
||||||
type mockAccountRepoForGemini struct {
|
type mockAccountRepoForGemini struct {
|
||||||
accounts []Account
|
accounts []Account
|
||||||
accountsByID map[int64]*Account
|
accountsByID map[int64]*Account
|
||||||
|
listByGroupFunc func(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||||
|
listByPlatformFunc func(ctx context.Context, platforms []string) ([]Account, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
|
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||||
@@ -107,6 +109,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context,
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||||
|
if m.listByPlatformFunc != nil {
|
||||||
|
return m.listByPlatformFunc(ctx, platforms)
|
||||||
|
}
|
||||||
var result []Account
|
var result []Account
|
||||||
platformSet := make(map[string]bool)
|
platformSet := make(map[string]bool)
|
||||||
for _, p := range platforms {
|
for _, p := range platforms {
|
||||||
@@ -120,6 +125,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Contex
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||||
|
if m.listByGroupFunc != nil {
|
||||||
|
return m.listByGroupFunc(ctx, groupID, platforms)
|
||||||
|
}
|
||||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||||
}
|
}
|
||||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
@@ -215,6 +223,7 @@ var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
|
|||||||
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
|
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
|
||||||
type mockGatewayCacheForGemini struct {
|
type mockGatewayCacheForGemini struct {
|
||||||
sessionBindings map[string]int64
|
sessionBindings map[string]int64
|
||||||
|
deletedSessions map[string]int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||||
@@ -236,6 +245,18 @@ func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, group
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||||
|
if m.sessionBindings == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if m.deletedSessions == nil {
|
||||||
|
m.deletedSessions = make(map[string]int)
|
||||||
|
}
|
||||||
|
m.deletedSessions[sessionHash]++
|
||||||
|
delete(m.sessionBindings, sessionHash)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
|
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
|
||||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
|
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
@@ -526,6 +547,274 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyS
|
|||||||
// 粘性会话未命中,按优先级选择
|
// 粘性会话未命中,按优先级选择
|
||||||
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择")
|
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("粘性会话不可调度-清理并回退选择", func(t *testing.T) {
|
||||||
|
repo := &mockAccountRepoForGemini{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusDisabled, Schedulable: true},
|
||||||
|
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForGemini{
|
||||||
|
sessionBindings: map[string]int64{"gemini:session-123": 1},
|
||||||
|
}
|
||||||
|
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||||
|
|
||||||
|
svc := &GeminiMessagesCompatService{
|
||||||
|
accountRepo: repo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(2), acc.ID)
|
||||||
|
require.Equal(t, 1, cache.deletedSessions["gemini:session-123"])
|
||||||
|
require.Equal(t, int64(2), cache.sessionBindings["gemini:session-123"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ForcePlatformFallback(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
groupID := int64(9)
|
||||||
|
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAntigravity)
|
||||||
|
|
||||||
|
repo := &mockAccountRepoForGemini{
|
||||||
|
listByGroupFunc: func(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
|
||||||
|
return []Account{
|
||||||
|
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{
|
||||||
|
1: {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForGemini{}
|
||||||
|
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||||
|
|
||||||
|
svc := &GeminiMessagesCompatService{
|
||||||
|
accountRepo: repo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(1), acc.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
repo := &mockAccountRepoForGemini{
|
||||||
|
accounts: []Account{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformGemini,
|
||||||
|
Priority: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.0-pro": "gemini-1.0-pro"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForGemini{}
|
||||||
|
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||||
|
|
||||||
|
svc := &GeminiMessagesCompatService{
|
||||||
|
accountRepo: repo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, acc)
|
||||||
|
require.Contains(t, err.Error(), "supporting model")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyMixedScheduling(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
repo := &mockAccountRepoForGemini{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
|
||||||
|
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForGemini{
|
||||||
|
sessionBindings: map[string]int64{"gemini:session-999": 1},
|
||||||
|
}
|
||||||
|
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||||
|
|
||||||
|
svc := &GeminiMessagesCompatService{
|
||||||
|
accountRepo: repo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-999", "gemini-2.5-flash", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(1), acc.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_SkipDisabledMixedScheduling(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
repo := &mockAccountRepoForGemini{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||||
|
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForGemini{}
|
||||||
|
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||||
|
|
||||||
|
svc := &GeminiMessagesCompatService{
|
||||||
|
accountRepo: repo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(2), acc.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludedAccount(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
repo := &mockAccountRepoForGemini{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||||
|
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForGemini{}
|
||||||
|
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||||
|
|
||||||
|
svc := &GeminiMessagesCompatService{
|
||||||
|
accountRepo: repo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
excluded := map[int64]struct{}{1: {}}
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", excluded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(2), acc.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ListError(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
repo := &mockAccountRepoForGemini{
|
||||||
|
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
|
||||||
|
return nil, errors.New("query failed")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForGemini{}
|
||||||
|
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||||
|
|
||||||
|
svc := &GeminiMessagesCompatService{
|
||||||
|
accountRepo: repo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, acc)
|
||||||
|
require.Contains(t, err.Error(), "query accounts failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferOAuth(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
repo := &mockAccountRepoForGemini{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
|
||||||
|
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForGemini{}
|
||||||
|
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||||
|
|
||||||
|
svc := &GeminiMessagesCompatService{
|
||||||
|
accountRepo: repo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(2), acc.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferLeastRecentlyUsed(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
oldTime := time.Now().Add(-2 * time.Hour)
|
||||||
|
newTime := time.Now().Add(-1 * time.Hour)
|
||||||
|
repo := &mockAccountRepoForGemini{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &newTime},
|
||||||
|
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &oldTime},
|
||||||
|
},
|
||||||
|
accountsByID: map[int64]*Account{},
|
||||||
|
}
|
||||||
|
for i := range repo.accounts {
|
||||||
|
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := &mockGatewayCacheForGemini{}
|
||||||
|
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||||
|
|
||||||
|
svc := &GeminiMessagesCompatService{
|
||||||
|
accountRepo: repo,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, acc)
|
||||||
|
require.Equal(t, int64(2), acc.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
|
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
|
||||||
|
|||||||
72
backend/internal/service/gemini_native_signature_cleaner.go
Normal file
72
backend/internal/service/gemini_native_signature_cleaner.go
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
|
||||||
|
// 以避免跨账号签名验证错误。
|
||||||
|
//
|
||||||
|
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
|
||||||
|
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
|
||||||
|
//
|
||||||
|
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
|
||||||
|
// to avoid cross-account signature validation errors.
|
||||||
|
//
|
||||||
|
// When sticky session switches accounts (e.g., original account becomes unavailable),
|
||||||
|
// thoughtSignatures from the old account will cause validation failures on the new account.
|
||||||
|
// By removing these signatures, we allow the new account to generate valid signatures.
|
||||||
|
func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 JSON
|
||||||
|
var data any
|
||||||
|
if err := json.Unmarshal(body, &data); err != nil {
|
||||||
|
// 如果解析失败,返回原始 body(可能不是 JSON 或格式不正确)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// 递归清理 thoughtSignature
|
||||||
|
cleaned := cleanThoughtSignaturesRecursive(data)
|
||||||
|
|
||||||
|
// 重新序列化
|
||||||
|
result, err := json.Marshal(cleaned)
|
||||||
|
if err != nil {
|
||||||
|
// 如果序列化失败,返回原始 body
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
|
||||||
|
func cleanThoughtSignaturesRecursive(data any) any {
|
||||||
|
switch v := data.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
// 创建新的 map,移除 thoughtSignature
|
||||||
|
result := make(map[string]any, len(v))
|
||||||
|
for key, value := range v {
|
||||||
|
// 跳过 thoughtSignature 字段
|
||||||
|
if key == "thoughtSignature" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 递归处理嵌套结构
|
||||||
|
result[key] = cleanThoughtSignaturesRecursive(value)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
// 递归处理数组中的每个元素
|
||||||
|
result := make([]any, len(v))
|
||||||
|
for i, item := range v {
|
||||||
|
result[i] = cleanThoughtSignaturesRecursive(item)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
default:
|
||||||
|
// 基本类型(string, number, bool, null)直接返回
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
|
"log/slog"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -131,21 +132,32 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3) Populate cache with TTL.
|
// 3) Populate cache with TTL(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||||
if p.tokenCache != nil {
|
if p.tokenCache != nil {
|
||||||
ttl := 30 * time.Minute
|
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||||
if expiresAt != nil {
|
if isStale && latestAccount != nil {
|
||||||
until := time.Until(*expiresAt)
|
// 版本过时,使用 DB 中的最新 token
|
||||||
switch {
|
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
|
||||||
case until > geminiTokenCacheSkew:
|
accessToken = latestAccount.GetCredential("access_token")
|
||||||
ttl = until - geminiTokenCacheSkew
|
if strings.TrimSpace(accessToken) == "" {
|
||||||
case until > 0:
|
return "", errors.New("access_token not found after version check")
|
||||||
ttl = until
|
|
||||||
default:
|
|
||||||
ttl = time.Minute
|
|
||||||
}
|
}
|
||||||
|
// 不写入缓存,让下次请求重新处理
|
||||||
|
} else {
|
||||||
|
ttl := 30 * time.Minute
|
||||||
|
if expiresAt != nil {
|
||||||
|
until := time.Until(*expiresAt)
|
||||||
|
switch {
|
||||||
|
case until > geminiTokenCacheSkew:
|
||||||
|
ttl = until - geminiTokenCacheSkew
|
||||||
|
case until > 0:
|
||||||
|
ttl = until
|
||||||
|
default:
|
||||||
|
ttl = time.Minute
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||||
}
|
}
|
||||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
|
|||||||
@@ -48,8 +48,7 @@ type GenerateAuthURLResult struct {
|
|||||||
|
|
||||||
// GenerateAuthURL generates an OAuth authorization URL with full scope
|
// GenerateAuthURL generates an OAuth authorization URL with full scope
|
||||||
func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
|
func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
|
||||||
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
|
return s.generateAuthURLWithScope(ctx, oauth.ScopeOAuth, proxyID)
|
||||||
return s.generateAuthURLWithScope(ctx, scope, proxyID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
|
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
|
||||||
@@ -123,6 +122,7 @@ type TokenInfo struct {
|
|||||||
Scope string `json:"scope,omitempty"`
|
Scope string `json:"scope,omitempty"`
|
||||||
OrgUUID string `json:"org_uuid,omitempty"`
|
OrgUUID string `json:"org_uuid,omitempty"`
|
||||||
AccountUUID string `json:"account_uuid,omitempty"`
|
AccountUUID string `json:"account_uuid,omitempty"`
|
||||||
|
EmailAddress string `json:"email_address,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExchangeCode exchanges authorization code for tokens
|
// ExchangeCode exchanges authorization code for tokens
|
||||||
@@ -176,7 +176,8 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Determine scope and if this is a setup token
|
// Determine scope and if this is a setup token
|
||||||
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
|
// Internal API call uses ScopeAPI (org:create_api_key not supported)
|
||||||
|
scope := oauth.ScopeAPI
|
||||||
isSetupToken := false
|
isSetupToken := false
|
||||||
if input.Scope == "inference" {
|
if input.Scope == "inference" {
|
||||||
scope = oauth.ScopeInference
|
scope = oauth.ScopeInference
|
||||||
@@ -252,9 +253,15 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
|
|||||||
tokenInfo.OrgUUID = tokenResp.Organization.UUID
|
tokenInfo.OrgUUID = tokenResp.Organization.UUID
|
||||||
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
|
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
|
||||||
}
|
}
|
||||||
if tokenResp.Account != nil && tokenResp.Account.UUID != "" {
|
if tokenResp.Account != nil {
|
||||||
tokenInfo.AccountUUID = tokenResp.Account.UUID
|
if tokenResp.Account.UUID != "" {
|
||||||
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
|
tokenInfo.AccountUUID = tokenResp.Account.UUID
|
||||||
|
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
|
||||||
|
}
|
||||||
|
if tokenResp.Account.EmailAddress != "" {
|
||||||
|
tokenInfo.EmailAddress = tokenResp.Account.EmailAddress
|
||||||
|
log.Printf("[OAuth] Got email_address: %s", tokenInfo.EmailAddress)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return tokenInfo, nil
|
return tokenInfo, nil
|
||||||
|
|||||||
@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct {
|
|||||||
UpdatedAt string `json:"updated_at,omitempty"`
|
UpdatedAt string `json:"updated_at,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizedCodexLimits contains normalized 5h/7d rate limit data
|
||||||
|
type NormalizedCodexLimits struct {
|
||||||
|
Used5hPercent *float64
|
||||||
|
Reset5hSeconds *int
|
||||||
|
Window5hMinutes *int
|
||||||
|
Used7dPercent *float64
|
||||||
|
Reset7dSeconds *int
|
||||||
|
Window7dMinutes *int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize converts primary/secondary fields to canonical 5h/7d fields.
|
||||||
|
// Strategy: Compare window_minutes to determine which is 5h vs 7d.
|
||||||
|
// Returns nil if snapshot is nil or has no useful data.
|
||||||
|
func (s *OpenAICodexUsageSnapshot) Normalize() *NormalizedCodexLimits {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &NormalizedCodexLimits{}
|
||||||
|
|
||||||
|
primaryMins := 0
|
||||||
|
secondaryMins := 0
|
||||||
|
hasPrimaryWindow := false
|
||||||
|
hasSecondaryWindow := false
|
||||||
|
|
||||||
|
if s.PrimaryWindowMinutes != nil {
|
||||||
|
primaryMins = *s.PrimaryWindowMinutes
|
||||||
|
hasPrimaryWindow = true
|
||||||
|
}
|
||||||
|
if s.SecondaryWindowMinutes != nil {
|
||||||
|
secondaryMins = *s.SecondaryWindowMinutes
|
||||||
|
hasSecondaryWindow = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine mapping based on window_minutes
|
||||||
|
use5hFromPrimary := false
|
||||||
|
use7dFromPrimary := false
|
||||||
|
|
||||||
|
if hasPrimaryWindow && hasSecondaryWindow {
|
||||||
|
// Both known: smaller window is 5h, larger is 7d
|
||||||
|
if primaryMins < secondaryMins {
|
||||||
|
use5hFromPrimary = true
|
||||||
|
} else {
|
||||||
|
use7dFromPrimary = true
|
||||||
|
}
|
||||||
|
} else if hasPrimaryWindow {
|
||||||
|
// Only primary known: classify by threshold (<=360 min = 6h -> 5h window)
|
||||||
|
if primaryMins <= 360 {
|
||||||
|
use5hFromPrimary = true
|
||||||
|
} else {
|
||||||
|
use7dFromPrimary = true
|
||||||
|
}
|
||||||
|
} else if hasSecondaryWindow {
|
||||||
|
// Only secondary known: classify by threshold
|
||||||
|
if secondaryMins <= 360 {
|
||||||
|
// 5h from secondary, so primary (if any data) is 7d
|
||||||
|
use7dFromPrimary = true
|
||||||
|
} else {
|
||||||
|
// 7d from secondary, so primary (if any data) is 5h
|
||||||
|
use5hFromPrimary = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h)
|
||||||
|
use7dFromPrimary = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign values
|
||||||
|
if use5hFromPrimary {
|
||||||
|
result.Used5hPercent = s.PrimaryUsedPercent
|
||||||
|
result.Reset5hSeconds = s.PrimaryResetAfterSeconds
|
||||||
|
result.Window5hMinutes = s.PrimaryWindowMinutes
|
||||||
|
result.Used7dPercent = s.SecondaryUsedPercent
|
||||||
|
result.Reset7dSeconds = s.SecondaryResetAfterSeconds
|
||||||
|
result.Window7dMinutes = s.SecondaryWindowMinutes
|
||||||
|
} else if use7dFromPrimary {
|
||||||
|
result.Used7dPercent = s.PrimaryUsedPercent
|
||||||
|
result.Reset7dSeconds = s.PrimaryResetAfterSeconds
|
||||||
|
result.Window7dMinutes = s.PrimaryWindowMinutes
|
||||||
|
result.Used5hPercent = s.SecondaryUsedPercent
|
||||||
|
result.Reset5hSeconds = s.SecondaryResetAfterSeconds
|
||||||
|
result.Window5hMinutes = s.SecondaryWindowMinutes
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// OpenAIUsage represents OpenAI API response usage
|
// OpenAIUsage represents OpenAI API response usage
|
||||||
type OpenAIUsage struct {
|
type OpenAIUsage struct {
|
||||||
InputTokens int `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
@@ -180,67 +266,26 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||||
|
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
|
||||||
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||||
// 1. Check sticky session
|
cacheKey := "openai:" + sessionHash
|
||||||
if sessionHash != "" {
|
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
// 1. 尝试粘性会话命中
|
||||||
if err == nil && accountID > 0 {
|
// Try sticky session hit
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil {
|
||||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
return account, nil
|
||||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
|
||||||
// Refresh sticky session TTL
|
|
||||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
|
||||||
return account, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Get schedulable OpenAI accounts
|
// 2. 获取可调度的 OpenAI 账号
|
||||||
|
// Get schedulable OpenAI accounts
|
||||||
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Select by priority + LRU
|
// 3. 按优先级 + LRU 选择最佳账号
|
||||||
var selected *Account
|
// Select by priority + LRU
|
||||||
for i := range accounts {
|
selected := s.selectBestAccount(accounts, requestedModel, excludedIDs)
|
||||||
acc := &accounts[i]
|
|
||||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
|
|
||||||
// avoid selecting accounts that were recently rate-limited/overloaded.
|
|
||||||
if !acc.IsSchedulable() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Check model support
|
|
||||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if selected == nil {
|
|
||||||
selected = acc
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Lower priority value means higher priority
|
|
||||||
if acc.Priority < selected.Priority {
|
|
||||||
selected = acc
|
|
||||||
} else if acc.Priority == selected.Priority {
|
|
||||||
switch {
|
|
||||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
|
||||||
selected = acc
|
|
||||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
|
||||||
// keep selected (never used is preferred)
|
|
||||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
|
||||||
// keep selected (both never used)
|
|
||||||
default:
|
|
||||||
// Same priority, select least recently used
|
|
||||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
|
||||||
selected = acc
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if selected == nil {
|
if selected == nil {
|
||||||
if requestedModel != "" {
|
if requestedModel != "" {
|
||||||
@@ -249,14 +294,138 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
|||||||
return nil, errors.New("no available OpenAI accounts")
|
return nil, errors.New("no available OpenAI accounts")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Set sticky session
|
// 4. 设置粘性会话绑定
|
||||||
|
// Set sticky session binding
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL)
|
||||||
}
|
}
|
||||||
|
|
||||||
return selected, nil
|
return selected, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tryStickySessionHit 尝试从粘性会话获取账号。
|
||||||
|
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
|
||||||
|
//
|
||||||
|
// tryStickySessionHit attempts to get account from sticky session.
|
||||||
|
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
|
||||||
|
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||||
|
if sessionHash == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||||
|
if err != nil || accountID <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, excluded := excludedIDs[accountID]; excluded {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查账号是否需要清理粘性会话
|
||||||
|
// Check if sticky session should be cleared
|
||||||
|
if shouldClearStickySession(account) {
|
||||||
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证账号是否可用于当前请求
|
||||||
|
// Verify account is usable for current request
|
||||||
|
if !account.IsSchedulable() || !account.IsOpenAI() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 刷新会话 TTL 并返回账号
|
||||||
|
// Refresh session TTL and return account
|
||||||
|
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL)
|
||||||
|
return account
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU)。
|
||||||
|
// 返回 nil 表示无可用账号。
|
||||||
|
//
|
||||||
|
// selectBestAccount selects the best account from candidates (priority + LRU).
|
||||||
|
// Returns nil if no available account.
|
||||||
|
func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||||
|
var selected *Account
|
||||||
|
|
||||||
|
for i := range accounts {
|
||||||
|
acc := &accounts[i]
|
||||||
|
|
||||||
|
// 跳过被排除的账号
|
||||||
|
// Skip excluded accounts
|
||||||
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调度器快照可能暂时过时,这里重新检查可调度性和平台
|
||||||
|
// Scheduler snapshots can be temporarily stale; re-check schedulability and platform
|
||||||
|
if !acc.IsSchedulable() || !acc.IsOpenAI() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查模型支持
|
||||||
|
// Check model support
|
||||||
|
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 选择优先级最高且最久未使用的账号
|
||||||
|
// Select highest priority and least recently used
|
||||||
|
if selected == nil {
|
||||||
|
selected = acc
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.isBetterAccount(acc, selected) {
|
||||||
|
selected = acc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
|
// isBetterAccount 判断 candidate 是否比 current 更优。
|
||||||
|
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。
|
||||||
|
//
|
||||||
|
// isBetterAccount checks if candidate is better than current.
|
||||||
|
// Rules: higher priority (lower value) wins; same priority: never used > least recently used.
|
||||||
|
func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool {
|
||||||
|
// 优先级更高(数值更小)
|
||||||
|
// Higher priority (lower value)
|
||||||
|
if candidate.Priority < current.Priority {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if candidate.Priority > current.Priority {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 同优先级,比较最后使用时间
|
||||||
|
// Same priority, compare last used time
|
||||||
|
switch {
|
||||||
|
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
|
||||||
|
// candidate 从未使用,优先
|
||||||
|
return true
|
||||||
|
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
|
||||||
|
// current 从未使用,保持
|
||||||
|
return false
|
||||||
|
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
|
||||||
|
// 都未使用,保持
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
// 都使用过,选择最久未使用的
|
||||||
|
return candidate.LastUsedAt.Before(*current.LastUsedAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
||||||
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||||
cfg := s.schedulingConfig()
|
cfg := s.schedulingConfig()
|
||||||
@@ -325,29 +494,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
if err == nil {
|
||||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
clearSticky := shouldClearStickySession(account)
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
if clearSticky {
|
||||||
if err == nil && result.Acquired {
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
|
||||||
return &AccountSelectionResult{
|
|
||||||
Account: account,
|
|
||||||
Acquired: true,
|
|
||||||
ReleaseFunc: result.ReleaseFunc,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
|
||||||
|
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
|
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
|
if err == nil && result.Acquired {
|
||||||
|
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: account,
|
||||||
|
Acquired: true,
|
||||||
|
ReleaseFunc: result.ReleaseFunc,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
Account: account,
|
Account: account,
|
||||||
WaitPlan: &AccountWaitPlan{
|
WaitPlan: &AccountWaitPlan{
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
MaxConcurrency: account.Concurrency,
|
MaxConcurrency: account.Concurrency,
|
||||||
Timeout: cfg.StickySessionWaitTimeout,
|
Timeout: cfg.StickySessionWaitTimeout,
|
||||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -778,7 +953,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
|
|
||||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||||
if account.Type == AccountTypeOAuth {
|
if account.Type == AccountTypeOAuth {
|
||||||
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
|
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1576,8 +1751,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractCodexUsageHeaders extracts Codex usage limits from response headers
|
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
|
||||||
func extractCodexUsageHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
|
||||||
|
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
||||||
snapshot := &OpenAICodexUsageSnapshot{}
|
snapshot := &OpenAICodexUsageSnapshot{}
|
||||||
hasData := false
|
hasData := false
|
||||||
|
|
||||||
@@ -1651,6 +1827,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
|||||||
|
|
||||||
// Convert snapshot to map for merging into Extra
|
// Convert snapshot to map for merging into Extra
|
||||||
updates := make(map[string]any)
|
updates := make(map[string]any)
|
||||||
|
|
||||||
|
// Save raw primary/secondary fields for debugging/tracing
|
||||||
if snapshot.PrimaryUsedPercent != nil {
|
if snapshot.PrimaryUsedPercent != nil {
|
||||||
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
|
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||||
}
|
}
|
||||||
@@ -1674,109 +1852,25 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
|||||||
}
|
}
|
||||||
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
|
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
|
||||||
|
|
||||||
// Normalize to canonical 5h/7d fields based on window_minutes
|
// Normalize to canonical 5h/7d fields
|
||||||
// This fixes the issue where OpenAI's primary/secondary naming is reversed
|
if normalized := snapshot.Normalize(); normalized != nil {
|
||||||
// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d
|
if normalized.Used5hPercent != nil {
|
||||||
|
updates["codex_5h_used_percent"] = *normalized.Used5hPercent
|
||||||
// IMPORTANT: We can only reliably determine window type from window_minutes field
|
|
||||||
// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison
|
|
||||||
|
|
||||||
var primaryWindowMins, secondaryWindowMins int
|
|
||||||
var hasPrimaryWindow, hasSecondaryWindow bool
|
|
||||||
|
|
||||||
// Only use window_minutes for reliable window size comparison
|
|
||||||
if snapshot.PrimaryWindowMinutes != nil {
|
|
||||||
primaryWindowMins = *snapshot.PrimaryWindowMinutes
|
|
||||||
hasPrimaryWindow = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if snapshot.SecondaryWindowMinutes != nil {
|
|
||||||
secondaryWindowMins = *snapshot.SecondaryWindowMinutes
|
|
||||||
hasSecondaryWindow = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine which is 5h and which is 7d
|
|
||||||
var use5hFromPrimary, use7dFromPrimary bool
|
|
||||||
var use5hFromSecondary, use7dFromSecondary bool
|
|
||||||
|
|
||||||
if hasPrimaryWindow && hasSecondaryWindow {
|
|
||||||
// Both window sizes known: compare and assign smaller to 5h, larger to 7d
|
|
||||||
if primaryWindowMins < secondaryWindowMins {
|
|
||||||
use5hFromPrimary = true
|
|
||||||
use7dFromSecondary = true
|
|
||||||
} else {
|
|
||||||
use5hFromSecondary = true
|
|
||||||
use7dFromPrimary = true
|
|
||||||
}
|
}
|
||||||
} else if hasPrimaryWindow {
|
if normalized.Reset5hSeconds != nil {
|
||||||
// Only primary window size known: classify by absolute threshold
|
updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds
|
||||||
if primaryWindowMins <= 360 {
|
|
||||||
use5hFromPrimary = true
|
|
||||||
} else {
|
|
||||||
use7dFromPrimary = true
|
|
||||||
}
|
}
|
||||||
} else if hasSecondaryWindow {
|
if normalized.Window5hMinutes != nil {
|
||||||
// Only secondary window size known: classify by absolute threshold
|
updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes
|
||||||
if secondaryWindowMins <= 360 {
|
|
||||||
use5hFromSecondary = true
|
|
||||||
} else {
|
|
||||||
use7dFromSecondary = true
|
|
||||||
}
|
}
|
||||||
} else {
|
if normalized.Used7dPercent != nil {
|
||||||
// No window_minutes available: cannot reliably determine window types
|
updates["codex_7d_used_percent"] = *normalized.Used7dPercent
|
||||||
// Fall back to legacy assumption (may be incorrect)
|
|
||||||
// Assume primary=7d, secondary=5h based on historical observation
|
|
||||||
if snapshot.SecondaryUsedPercent != nil || snapshot.SecondaryResetAfterSeconds != nil || snapshot.SecondaryWindowMinutes != nil {
|
|
||||||
use5hFromSecondary = true
|
|
||||||
}
|
}
|
||||||
if snapshot.PrimaryUsedPercent != nil || snapshot.PrimaryResetAfterSeconds != nil || snapshot.PrimaryWindowMinutes != nil {
|
if normalized.Reset7dSeconds != nil {
|
||||||
use7dFromPrimary = true
|
updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds
|
||||||
}
|
}
|
||||||
}
|
if normalized.Window7dMinutes != nil {
|
||||||
|
updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes
|
||||||
// Write canonical 5h fields
|
|
||||||
if use5hFromPrimary {
|
|
||||||
if snapshot.PrimaryUsedPercent != nil {
|
|
||||||
updates["codex_5h_used_percent"] = *snapshot.PrimaryUsedPercent
|
|
||||||
}
|
|
||||||
if snapshot.PrimaryResetAfterSeconds != nil {
|
|
||||||
updates["codex_5h_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
|
|
||||||
}
|
|
||||||
if snapshot.PrimaryWindowMinutes != nil {
|
|
||||||
updates["codex_5h_window_minutes"] = *snapshot.PrimaryWindowMinutes
|
|
||||||
}
|
|
||||||
} else if use5hFromSecondary {
|
|
||||||
if snapshot.SecondaryUsedPercent != nil {
|
|
||||||
updates["codex_5h_used_percent"] = *snapshot.SecondaryUsedPercent
|
|
||||||
}
|
|
||||||
if snapshot.SecondaryResetAfterSeconds != nil {
|
|
||||||
updates["codex_5h_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
|
|
||||||
}
|
|
||||||
if snapshot.SecondaryWindowMinutes != nil {
|
|
||||||
updates["codex_5h_window_minutes"] = *snapshot.SecondaryWindowMinutes
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write canonical 7d fields
|
|
||||||
if use7dFromPrimary {
|
|
||||||
if snapshot.PrimaryUsedPercent != nil {
|
|
||||||
updates["codex_7d_used_percent"] = *snapshot.PrimaryUsedPercent
|
|
||||||
}
|
|
||||||
if snapshot.PrimaryResetAfterSeconds != nil {
|
|
||||||
updates["codex_7d_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
|
|
||||||
}
|
|
||||||
if snapshot.PrimaryWindowMinutes != nil {
|
|
||||||
updates["codex_7d_window_minutes"] = *snapshot.PrimaryWindowMinutes
|
|
||||||
}
|
|
||||||
} else if use7dFromSecondary {
|
|
||||||
if snapshot.SecondaryUsedPercent != nil {
|
|
||||||
updates["codex_7d_used_percent"] = *snapshot.SecondaryUsedPercent
|
|
||||||
}
|
|
||||||
if snapshot.SecondaryResetAfterSeconds != nil {
|
|
||||||
updates["codex_7d_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
|
|
||||||
}
|
|
||||||
if snapshot.SecondaryWindowMinutes != nil {
|
|
||||||
updates["codex_7d_window_minutes"] = *snapshot.SecondaryWindowMinutes
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,19 +21,50 @@ type stubOpenAIAccountRepo struct {
|
|||||||
accounts []Account
|
accounts []Account
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||||
|
for i := range r.accounts {
|
||||||
|
if r.accounts[i].ID == id {
|
||||||
|
return &r.accounts[i], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, errors.New("account not found")
|
||||||
|
}
|
||||||
|
|
||||||
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||||
return append([]Account(nil), r.accounts...), nil
|
var result []Account
|
||||||
|
for _, acc := range r.accounts {
|
||||||
|
if acc.Platform == platform {
|
||||||
|
result = append(result, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||||
return append([]Account(nil), r.accounts...), nil
|
var result []Account
|
||||||
|
for _, acc := range r.accounts {
|
||||||
|
if acc.Platform == platform {
|
||||||
|
result = append(result, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type stubConcurrencyCache struct {
|
type stubConcurrencyCache struct {
|
||||||
ConcurrencyCache
|
ConcurrencyCache
|
||||||
|
loadBatchErr error
|
||||||
|
loadMap map[int64]*AccountLoadInfo
|
||||||
|
acquireResults map[int64]bool
|
||||||
|
waitCounts map[int64]int
|
||||||
|
skipDefaultLoad bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
if c.acquireResults != nil {
|
||||||
|
if result, ok := c.acquireResults[accountID]; ok {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -42,8 +73,25 @@ func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||||
|
if c.loadBatchErr != nil {
|
||||||
|
return nil, c.loadBatchErr
|
||||||
|
}
|
||||||
out := make(map[int64]*AccountLoadInfo, len(accounts))
|
out := make(map[int64]*AccountLoadInfo, len(accounts))
|
||||||
|
if c.skipDefaultLoad && c.loadMap != nil {
|
||||||
|
for _, acc := range accounts {
|
||||||
|
if load, ok := c.loadMap[acc.ID]; ok {
|
||||||
|
out[acc.ID] = load
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
for _, acc := range accounts {
|
for _, acc := range accounts {
|
||||||
|
if c.loadMap != nil {
|
||||||
|
if load, ok := c.loadMap[acc.ID]; ok {
|
||||||
|
out[acc.ID] = load
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
|
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
@@ -92,6 +140,51 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
if c.waitCounts != nil {
|
||||||
|
if count, ok := c.waitCounts[accountID]; ok {
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubGatewayCache struct {
|
||||||
|
sessionBindings map[string]int64
|
||||||
|
deletedSessions map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stubGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||||
|
if id, ok := c.sessionBindings[sessionHash]; ok {
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stubGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||||
|
if c.sessionBindings == nil {
|
||||||
|
c.sessionBindings = make(map[string]int64)
|
||||||
|
}
|
||||||
|
c.sessionBindings[sessionHash] = accountID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stubGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||||
|
if c.sessionBindings == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if c.deletedSessions == nil {
|
||||||
|
c.deletedSessions = make(map[string]int)
|
||||||
|
}
|
||||||
|
c.deletedSessions[sessionHash]++
|
||||||
|
delete(c.sessionBindings, sessionHash)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
resetAt := now.Add(10 * time.Minute)
|
resetAt := now.Add(10 * time.Minute)
|
||||||
@@ -182,6 +275,515 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) {
|
||||||
|
sessionHash := "session-1"
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{
|
||||||
|
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||||
|
}
|
||||||
|
if acc == nil || acc.ID != 2 {
|
||||||
|
t.Fatalf("expected account 2, got %+v", acc)
|
||||||
|
}
|
||||||
|
if cache.deletedSessions["openai:"+sessionHash] != 1 {
|
||||||
|
t.Fatalf("expected sticky session to be deleted")
|
||||||
|
}
|
||||||
|
if cache.sessionBindings["openai:"+sessionHash] != 2 {
|
||||||
|
t.Fatalf("expected sticky session to bind to account 2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountWithLoadAwareness_StickyUnschedulableClearsSession(t *testing.T) {
|
||||||
|
sessionHash := "session-2"
|
||||||
|
groupID := int64(1)
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{
|
||||||
|
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||||
|
}
|
||||||
|
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||||
|
t.Fatalf("expected account 2, got %+v", selection)
|
||||||
|
}
|
||||||
|
if cache.deletedSessions["openai:"+sessionHash] != 1 {
|
||||||
|
t.Fatalf("expected sticky session to be deleted")
|
||||||
|
}
|
||||||
|
if cache.sessionBindings["openai:"+sessionHash] != 2 {
|
||||||
|
t.Fatalf("expected sticky session to bind to account 2")
|
||||||
|
}
|
||||||
|
if selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Credentials: map[string]any{"model_mapping": map[string]any{"gpt-3.5-turbo": "gpt-3.5-turbo"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error for unsupported model")
|
||||||
|
}
|
||||||
|
if acc != nil {
|
||||||
|
t.Fatalf("expected nil account for unsupported model")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "supporting model") {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorFallback(t *testing.T) {
|
||||||
|
groupID := int64(1)
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
concurrencyCache := stubConcurrencyCache{
|
||||||
|
loadBatchErr: errors.New("load batch failed"),
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "fallback", "gpt-4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||||
|
}
|
||||||
|
if selection == nil || selection.Account == nil {
|
||||||
|
t.Fatalf("expected selection")
|
||||||
|
}
|
||||||
|
if selection.Account.ID != 2 {
|
||||||
|
t.Fatalf("expected account 2, got %d", selection.Account.ID)
|
||||||
|
}
|
||||||
|
if cache.sessionBindings["openai:fallback"] != 2 {
|
||||||
|
t.Fatalf("expected sticky session updated")
|
||||||
|
}
|
||||||
|
if selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountWithLoadAwareness_NoSlotFallbackWait(t *testing.T) {
|
||||||
|
groupID := int64(1)
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
concurrencyCache := stubConcurrencyCache{
|
||||||
|
acquireResults: map[int64]bool{1: false},
|
||||||
|
loadMap: map[int64]*AccountLoadInfo{
|
||||||
|
1: {AccountID: 1, LoadRate: 10},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||||
|
}
|
||||||
|
if selection == nil || selection.WaitPlan == nil {
|
||||||
|
t.Fatalf("expected wait plan fallback")
|
||||||
|
}
|
||||||
|
if selection.Account == nil || selection.Account.ID != 1 {
|
||||||
|
t.Fatalf("expected account 1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountForModelWithExclusions_SetsStickyBinding(t *testing.T) {
|
||||||
|
sessionHash := "bind"
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||||
|
}
|
||||||
|
if acc == nil || acc.ID != 1 {
|
||||||
|
t.Fatalf("expected account 1")
|
||||||
|
}
|
||||||
|
if cache.sessionBindings["openai:"+sessionHash] != 1 {
|
||||||
|
t.Fatalf("expected sticky session binding")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountWithLoadAwareness_StickyWaitPlan(t *testing.T) {
|
||||||
|
sessionHash := "sticky-wait"
|
||||||
|
groupID := int64(1)
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{
|
||||||
|
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||||
|
}
|
||||||
|
concurrencyCache := stubConcurrencyCache{
|
||||||
|
acquireResults: map[int64]bool{1: false},
|
||||||
|
waitCounts: map[int64]int{1: 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||||
|
}
|
||||||
|
if selection == nil || selection.WaitPlan == nil {
|
||||||
|
t.Fatalf("expected sticky wait plan")
|
||||||
|
}
|
||||||
|
if selection.Account == nil || selection.Account.ID != 1 {
|
||||||
|
t.Fatalf("expected account 1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountWithLoadAwareness_PrefersLowerLoad(t *testing.T) {
|
||||||
|
groupID := int64(1)
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
concurrencyCache := stubConcurrencyCache{
|
||||||
|
loadMap: map[int64]*AccountLoadInfo{
|
||||||
|
1: {AccountID: 1, LoadRate: 80},
|
||||||
|
2: {AccountID: 2, LoadRate: 10},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "load", "gpt-4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||||
|
}
|
||||||
|
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||||
|
t.Fatalf("expected account 2")
|
||||||
|
}
|
||||||
|
if cache.sessionBindings["openai:load"] != 2 {
|
||||||
|
t.Fatalf("expected sticky session updated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountForModelWithExclusions_StickyExcludedFallback(t *testing.T) {
|
||||||
|
sessionHash := "excluded"
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{
|
||||||
|
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
excluded := map[int64]struct{}{1: {}}
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", excluded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||||
|
}
|
||||||
|
if acc == nil || acc.ID != 2 {
|
||||||
|
t.Fatalf("expected account 2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountForModelWithExclusions_StickyNonOpenAI(t *testing.T) {
|
||||||
|
sessionHash := "non-openai"
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformAnthropic, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{
|
||||||
|
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||||
|
}
|
||||||
|
if acc == nil || acc.ID != 2 {
|
||||||
|
t.Fatalf("expected account 2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountForModelWithExclusions_NoAccounts(t *testing.T) {
|
||||||
|
repo := stubOpenAIAccountRepo{accounts: []Account{}}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "", nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error for no accounts")
|
||||||
|
}
|
||||||
|
if acc != nil {
|
||||||
|
t.Fatalf("expected nil account")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "no available OpenAI accounts") {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountWithLoadAwareness_NoCandidates(t *testing.T) {
|
||||||
|
groupID := int64(1)
|
||||||
|
resetAt := time.Now().Add(1 * time.Hour)
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, RateLimitResetAt: &resetAt},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
concurrencyCache := stubConcurrencyCache{}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error for no candidates")
|
||||||
|
}
|
||||||
|
if selection != nil {
|
||||||
|
t.Fatalf("expected nil selection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan(t *testing.T) {
|
||||||
|
groupID := int64(1)
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
concurrencyCache := stubConcurrencyCache{
|
||||||
|
loadMap: map[int64]*AccountLoadInfo{
|
||||||
|
1: {AccountID: 1, LoadRate: 100},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||||
|
}
|
||||||
|
if selection == nil || selection.WaitPlan == nil {
|
||||||
|
t.Fatalf("expected wait plan")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire(t *testing.T) {
|
||||||
|
groupID := int64(1)
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
concurrencyCache := stubConcurrencyCache{
|
||||||
|
loadBatchErr: errors.New("load batch failed"),
|
||||||
|
acquireResults: map[int64]bool{1: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||||
|
}
|
||||||
|
if selection == nil || selection.WaitPlan == nil {
|
||||||
|
t.Fatalf("expected wait plan")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo(t *testing.T) {
|
||||||
|
groupID := int64(1)
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
concurrencyCache := stubConcurrencyCache{
|
||||||
|
loadMap: map[int64]*AccountLoadInfo{
|
||||||
|
1: {AccountID: 1, LoadRate: 50},
|
||||||
|
},
|
||||||
|
skipDefaultLoad: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||||
|
}
|
||||||
|
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||||
|
t.Fatalf("expected account 2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing.T) {
|
||||||
|
oldTime := time.Now().Add(-2 * time.Hour)
|
||||||
|
newTime := time.Now().Add(-1 * time.Hour)
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &newTime},
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &oldTime},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||||
|
}
|
||||||
|
if acc == nil || acc.ID != 2 {
|
||||||
|
t.Fatalf("expected account 2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) {
|
||||||
|
groupID := int64(1)
|
||||||
|
lastUsed := time.Now().Add(-1 * time.Hour)
|
||||||
|
repo := stubOpenAIAccountRepo{
|
||||||
|
accounts: []Account{
|
||||||
|
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, LastUsedAt: &lastUsed},
|
||||||
|
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache := &stubGatewayCache{}
|
||||||
|
concurrencyCache := stubConcurrencyCache{
|
||||||
|
loadMap: map[int64]*AccountLoadInfo{
|
||||||
|
1: {AccountID: 1, LoadRate: 10},
|
||||||
|
2: {AccountID: 2, LoadRate: 10},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||||
|
}
|
||||||
|
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||||
|
t.Fatalf("expected account 2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIStreamingTimeout(t *testing.T) {
|
func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -35,12 +36,12 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
|||||||
// Generate PKCE values
|
// Generate PKCE values
|
||||||
state, err := openai.GenerateState()
|
state, err := openai.GenerateState()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_STATE_FAILED", "failed to generate state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
codeVerifier, err := openai.GenerateCodeVerifier()
|
codeVerifier, err := openai.GenerateCodeVerifier()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_VERIFIER_FAILED", "failed to generate code verifier: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
|
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
|
||||||
@@ -48,14 +49,17 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
|||||||
// Generate session ID
|
// Generate session ID
|
||||||
sessionID, err := openai.GenerateSessionID()
|
sessionID, err := openai.GenerateSessionID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate session ID: %w", err)
|
return nil, infraerrors.Newf(http.StatusInternalServerError, "OPENAI_OAUTH_SESSION_FAILED", "failed to generate session ID: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get proxy URL if specified
|
// Get proxy URL if specified
|
||||||
var proxyURL string
|
var proxyURL string
|
||||||
if proxyID != nil {
|
if proxyID != nil {
|
||||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||||
if err == nil && proxy != nil {
|
if err != nil {
|
||||||
|
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
||||||
|
}
|
||||||
|
if proxy != nil {
|
||||||
proxyURL = proxy.URL()
|
proxyURL = proxy.URL()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -110,14 +114,17 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
|||||||
// Get session
|
// Get session
|
||||||
session, ok := s.sessionStore.Get(input.SessionID)
|
session, ok := s.sessionStore.Get(input.SessionID)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("session not found or expired")
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get proxy URL
|
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
|
||||||
proxyURL := session.ProxyURL
|
proxyURL := session.ProxyURL
|
||||||
if input.ProxyID != nil {
|
if input.ProxyID != nil {
|
||||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||||
if err == nil && proxy != nil {
|
if err != nil {
|
||||||
|
return nil, infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
|
||||||
|
}
|
||||||
|
if proxy != nil {
|
||||||
proxyURL = proxy.URL()
|
proxyURL = proxy.URL()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -131,7 +138,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
|||||||
// Exchange code for token
|
// Exchange code for token
|
||||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse ID token to get user info
|
// Parse ID token to get user info
|
||||||
@@ -201,12 +208,12 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
|
|||||||
// RefreshAccountToken refreshes token for an OpenAI account
|
// RefreshAccountToken refreshes token for an OpenAI account
|
||||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||||
if !account.IsOpenAI() {
|
if !account.IsOpenAI() {
|
||||||
return nil, fmt.Errorf("account is not an OpenAI account")
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshToken := account.GetOpenAIRefreshToken()
|
refreshToken := account.GetOpenAIRefreshToken()
|
||||||
if refreshToken == "" {
|
if refreshToken == "" {
|
||||||
return nil, fmt.Errorf("no refresh token available")
|
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
|
||||||
}
|
}
|
||||||
|
|
||||||
var proxyURL string
|
var proxyURL string
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user