mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 23:42:13 +08:00
Compare commits
166 Commits
feat/api-k
...
v0.1.55
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
99cbfa1567 | ||
|
|
3f8c8d70ad | ||
|
|
9c567fad92 | ||
|
|
0abb3a6843 | ||
|
|
3663951d11 | ||
|
|
55fced3942 | ||
|
|
7bbf49fd65 | ||
|
|
eea6c2d02c | ||
|
|
70eaa450db | ||
|
|
55796a118d | ||
|
|
d7fa47d732 | ||
|
|
3d6e01a58f | ||
|
|
f9713e8733 | ||
|
|
0e44829720 | ||
|
|
93db889a10 | ||
|
|
0df7385c4e | ||
|
|
1a3fa6411c | ||
|
|
64614756d1 | ||
|
|
bb1fd54d4d | ||
|
|
d85288a6c0 | ||
|
|
3402acb606 | ||
|
|
7fdc25df3c | ||
|
|
ea699cbdc2 | ||
|
|
fe6a3f4267 | ||
|
|
fe8198c8cd | ||
|
|
675e61385f | ||
|
|
67acac1082 | ||
|
|
d02e1db018 | ||
|
|
0da515071b | ||
|
|
524d80ae1c | ||
|
|
3b71bc3df1 | ||
|
|
22ef9534e0 | ||
|
|
c206d12d5c | ||
|
|
6ad29a470c | ||
|
|
2d45e61a9b | ||
|
|
b98fb013ae | ||
|
|
345a965fa3 | ||
|
|
c02c120579 | ||
|
|
4da681f58a | ||
|
|
68ba866c38 | ||
|
|
9622347faa | ||
|
|
8363663ea8 | ||
|
|
b588ea194c | ||
|
|
465ba76788 | ||
|
|
cf313d5761 | ||
|
|
9618cb5643 | ||
|
|
8c1958c9ad | ||
|
|
2db34139f0 | ||
|
|
9c02ab789d | ||
|
|
e0cccf6ed2 | ||
|
|
89c1a41305 | ||
|
|
202ec21bab | ||
|
|
6dcb27632e | ||
|
|
3141aa5144 | ||
|
|
5443efd7d7 | ||
|
|
62771583e7 | ||
|
|
5526f122b7 | ||
|
|
9c144587fe | ||
|
|
098bf5a1e8 | ||
|
|
4c37ca71ee | ||
|
|
0c52809591 | ||
|
|
53e730f8d5 | ||
|
|
8e248e0853 | ||
|
|
2a0758bdfe | ||
|
|
f55ba3f6c1 | ||
|
|
db51e65b42 | ||
|
|
72a2ed958b | ||
|
|
d0b91a40d4 | ||
|
|
bd74bf7994 | ||
|
|
f28d4b78e7 | ||
|
|
7536dbfee5 | ||
|
|
b76cc583fb | ||
|
|
955af6b3ec | ||
|
|
1073317a3e | ||
|
|
839ab37d40 | ||
|
|
9dd0ef187d | ||
|
|
fd8473f267 | ||
|
|
cc4910dd30 | ||
|
|
50de5d05b0 | ||
|
|
7844dc4f2d | ||
|
|
c48795a948 | ||
|
|
19b67e89a2 | ||
|
|
f017fd97c1 | ||
|
|
ce3336e3f4 | ||
|
|
54c5788b86 | ||
|
|
4cb7b26f03 | ||
|
|
3dfb62e996 | ||
|
|
d5c711d081 | ||
|
|
73b62bb15c | ||
|
|
18b8bd43ad | ||
|
|
8fffcd8091 | ||
|
|
c8e3a476fc | ||
|
|
808cee9665 | ||
|
|
92eafbc2a6 | ||
|
|
2548800c3f | ||
|
|
9dce8a5388 | ||
|
|
76484bd5c9 | ||
|
|
e4ed35fe01 | ||
|
|
f5e45c1a8a | ||
|
|
a2f83ff032 | ||
|
|
2b2f7a6dec | ||
|
|
49c15c0d44 | ||
|
|
1b938b2003 | ||
|
|
5f80760a8c | ||
|
|
dd59e872ff | ||
|
|
aa1a3b9a74 | ||
|
|
32953405b1 | ||
|
|
c1a3dd41dd | ||
|
|
63dc6a68df | ||
|
|
a39316e004 | ||
|
|
988b4d0254 | ||
|
|
f541636840 | ||
|
|
48613558d4 | ||
|
|
4b66ee2f8f | ||
|
|
abbde130ab | ||
|
|
ccb8144557 | ||
|
|
1240c78ef6 | ||
|
|
66c8b6f2bc | ||
|
|
6271a33d08 | ||
|
|
5364011a5b | ||
|
|
d78f42d2fd | ||
|
|
1a869547d7 | ||
|
|
e4bc9f6fb0 | ||
|
|
e5857161ff | ||
|
|
abdc4f39cb | ||
|
|
7ebca553ef | ||
|
|
c2962752eb | ||
|
|
ab5839b461 | ||
|
|
89a725a433 | ||
|
|
645609d441 | ||
|
|
fc4ea65936 | ||
|
|
d75cd820b0 | ||
|
|
cb3e08dda4 | ||
|
|
44a93c1922 | ||
|
|
9cba595fd0 | ||
|
|
56fc2764e4 | ||
|
|
0c4f1762c9 | ||
|
|
c2c865b0cb | ||
|
|
a66d318820 | ||
|
|
a16f72f52e | ||
|
|
99e2391b2a | ||
|
|
80c1cdf024 | ||
|
|
9d0a4f3d68 | ||
|
|
1a641392d9 | ||
|
|
36b817d008 | ||
|
|
24d19a5f78 | ||
|
|
13ae0ce7b0 | ||
|
|
3a67002cfe | ||
|
|
eb06006d6c | ||
|
|
c48dc097ff | ||
|
|
585257d340 | ||
|
|
8ae75e7f6e | ||
|
|
fc32b57798 | ||
|
|
337a188660 | ||
|
|
11d063e3c4 | ||
|
|
e846458009 | ||
|
|
2d123a11ad | ||
|
|
fcdf839b6b | ||
|
|
d55dd56fd2 | ||
|
|
e0d12b46d8 | ||
|
|
f3ed95d4de | ||
|
|
5baa8b5673 | ||
|
|
bb5303272b | ||
|
|
d55866d375 | ||
|
|
4b9e47cec9 | ||
|
|
7a06c4873e |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -126,6 +126,4 @@ backend/cmd/server/server
|
|||||||
deploy/docker-compose.override.yml
|
deploy/docker-compose.override.yml
|
||||||
.gocache/
|
.gocache/
|
||||||
vite.config.js
|
vite.config.js
|
||||||
!docs/
|
|
||||||
docs/*
|
docs/*
|
||||||
!docs/dependency-security.md
|
|
||||||
|
|||||||
@@ -57,6 +57,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## OpenAI Responses 兼容注意事项
|
||||||
|
|
||||||
|
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
||||||
|
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 部署方式
|
## 部署方式
|
||||||
|
|
||||||
### 方式一:脚本安装(推荐)
|
### 方式一:脚本安装(推荐)
|
||||||
|
|||||||
2
backend/.dockerignore
Normal file
2
backend/.dockerignore
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
.cache/
|
||||||
|
.DS_Store
|
||||||
@@ -18,6 +18,12 @@ linters:
|
|||||||
list-mode: original
|
list-mode: original
|
||||||
files:
|
files:
|
||||||
- "**/internal/service/**"
|
- "**/internal/service/**"
|
||||||
|
- "!**/internal/service/ops_aggregation_service.go"
|
||||||
|
- "!**/internal/service/ops_alert_evaluator_service.go"
|
||||||
|
- "!**/internal/service/ops_cleanup_service.go"
|
||||||
|
- "!**/internal/service/ops_metrics_collector.go"
|
||||||
|
- "!**/internal/service/ops_scheduled_report_service.go"
|
||||||
|
- "!**/internal/service/wire.go"
|
||||||
deny:
|
deny:
|
||||||
- pkg: github.com/Wei-Shaw/sub2api/internal/repository
|
- pkg: github.com/Wei-Shaw/sub2api/internal/repository
|
||||||
desc: "service must not import repository"
|
desc: "service must not import repository"
|
||||||
|
|||||||
@@ -62,6 +62,12 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
|||||||
func provideCleanup(
|
func provideCleanup(
|
||||||
entClient *ent.Client,
|
entClient *ent.Client,
|
||||||
rdb *redis.Client,
|
rdb *redis.Client,
|
||||||
|
opsMetricsCollector *service.OpsMetricsCollector,
|
||||||
|
opsAggregation *service.OpsAggregationService,
|
||||||
|
opsAlertEvaluator *service.OpsAlertEvaluatorService,
|
||||||
|
opsCleanup *service.OpsCleanupService,
|
||||||
|
opsScheduledReport *service.OpsScheduledReportService,
|
||||||
|
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||||
tokenRefresh *service.TokenRefreshService,
|
tokenRefresh *service.TokenRefreshService,
|
||||||
accountExpiry *service.AccountExpiryService,
|
accountExpiry *service.AccountExpiryService,
|
||||||
pricing *service.PricingService,
|
pricing *service.PricingService,
|
||||||
@@ -81,6 +87,42 @@ func provideCleanup(
|
|||||||
name string
|
name string
|
||||||
fn func() error
|
fn func() error
|
||||||
}{
|
}{
|
||||||
|
{"OpsScheduledReportService", func() error {
|
||||||
|
if opsScheduledReport != nil {
|
||||||
|
opsScheduledReport.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"OpsCleanupService", func() error {
|
||||||
|
if opsCleanup != nil {
|
||||||
|
opsCleanup.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"OpsAlertEvaluatorService", func() error {
|
||||||
|
if opsAlertEvaluator != nil {
|
||||||
|
opsAlertEvaluator.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"OpsAggregationService", func() error {
|
||||||
|
if opsAggregation != nil {
|
||||||
|
opsAggregation.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"OpsMetricsCollector", func() error {
|
||||||
|
if opsMetricsCollector != nil {
|
||||||
|
opsMetricsCollector.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"SchedulerSnapshotService", func() error {
|
||||||
|
if schedulerSnapshot != nil {
|
||||||
|
schedulerSnapshot.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"TokenRefreshService", func() error {
|
{"TokenRefreshService", func() error {
|
||||||
tokenRefresh.Stop()
|
tokenRefresh.Stop()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -55,31 +55,36 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
billingCache := repository.NewBillingCache(redisClient)
|
billingCache := repository.NewBillingCache(redisClient)
|
||||||
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
|
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
|
||||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
|
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
|
||||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client)
|
|
||||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
|
||||||
userService := service.NewUserService(userRepository)
|
|
||||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
|
|
||||||
userHandler := handler.NewUserHandler(userService)
|
|
||||||
apiKeyRepository := repository.NewAPIKeyRepository(client)
|
apiKeyRepository := repository.NewAPIKeyRepository(client)
|
||||||
groupRepository := repository.NewGroupRepository(client, db)
|
groupRepository := repository.NewGroupRepository(client, db)
|
||||||
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
||||||
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||||
|
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||||
|
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
|
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||||
|
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
|
||||||
|
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
|
||||||
|
userHandler := handler.NewUserHandler(userService)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client)
|
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
|
||||||
|
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
||||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
||||||
redeemCache := repository.NewRedeemCache(redisClient)
|
redeemCache := repository.NewRedeemCache(redisClient)
|
||||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client)
|
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||||
dashboardService := service.NewDashboardService(usageLogRepository)
|
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
|
||||||
dashboardHandler := admin.NewDashboardHandler(dashboardService)
|
timingWheelService := service.ProvideTimingWheelService()
|
||||||
|
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
|
||||||
|
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
|
||||||
|
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
||||||
accountRepository := repository.NewAccountRepository(client, db)
|
accountRepository := repository.NewAccountRepository(client, db)
|
||||||
proxyRepository := repository.NewProxyRepository(client, db)
|
proxyRepository := repository.NewProxyRepository(client, db)
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService)
|
adminUserHandler := admin.NewUserHandler(adminService)
|
||||||
groupHandler := admin.NewGroupHandler(adminService)
|
groupHandler := admin.NewGroupHandler(adminService)
|
||||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||||
@@ -92,7 +97,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
||||||
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
||||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||||
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache)
|
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||||
|
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService)
|
||||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
||||||
usageCache := service.NewUsageCache()
|
usageCache := service.NewUsageCache()
|
||||||
@@ -106,6 +112,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
|
schedulerCache := repository.NewSchedulerCache(redisClient)
|
||||||
|
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||||
|
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
||||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||||
@@ -115,7 +124,22 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
proxyHandler := admin.NewProxyHandler(adminService)
|
proxyHandler := admin.NewProxyHandler(adminService)
|
||||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||||
promoHandler := admin.NewPromoHandler(promoService)
|
promoHandler := admin.NewPromoHandler(promoService)
|
||||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService)
|
opsRepository := repository.NewOpsRepository(db)
|
||||||
|
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
||||||
|
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
billingService := service.NewBillingService(configConfig, pricingService)
|
||||||
|
identityCache := repository.NewIdentityCache(redisClient)
|
||||||
|
identityService := service.NewIdentityService(identityCache)
|
||||||
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
|
||||||
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
|
||||||
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
|
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
|
||||||
|
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
|
||||||
|
opsHandler := admin.NewOpsHandler(opsService)
|
||||||
updateCache := repository.NewUpdateCache(redisClient)
|
updateCache := repository.NewUpdateCache(redisClient)
|
||||||
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
||||||
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
|
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
|
||||||
@@ -127,32 +151,24 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
||||||
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||||
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
|
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
|
||||||
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
|
||||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
billingService := service.NewBillingService(configConfig, pricingService)
|
|
||||||
identityCache := repository.NewIdentityCache(redisClient)
|
|
||||||
identityService := service.NewIdentityService(identityCache)
|
|
||||||
timingWheelService := service.ProvideTimingWheelService()
|
|
||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
|
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
|
|
||||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
|
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
|
||||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
||||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, settingService, redisClient)
|
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient)
|
||||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||||
|
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
|
||||||
|
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
|
||||||
|
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||||
|
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||||
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
v := provideCleanup(client, redisClient, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
Cleanup: v,
|
Cleanup: v,
|
||||||
@@ -177,6 +193,12 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
|||||||
func provideCleanup(
|
func provideCleanup(
|
||||||
entClient *ent.Client,
|
entClient *ent.Client,
|
||||||
rdb *redis.Client,
|
rdb *redis.Client,
|
||||||
|
opsMetricsCollector *service.OpsMetricsCollector,
|
||||||
|
opsAggregation *service.OpsAggregationService,
|
||||||
|
opsAlertEvaluator *service.OpsAlertEvaluatorService,
|
||||||
|
opsCleanup *service.OpsCleanupService,
|
||||||
|
opsScheduledReport *service.OpsScheduledReportService,
|
||||||
|
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||||
tokenRefresh *service.TokenRefreshService,
|
tokenRefresh *service.TokenRefreshService,
|
||||||
accountExpiry *service.AccountExpiryService,
|
accountExpiry *service.AccountExpiryService,
|
||||||
pricing *service.PricingService,
|
pricing *service.PricingService,
|
||||||
@@ -195,6 +217,42 @@ func provideCleanup(
|
|||||||
name string
|
name string
|
||||||
fn func() error
|
fn func() error
|
||||||
}{
|
}{
|
||||||
|
{"OpsScheduledReportService", func() error {
|
||||||
|
if opsScheduledReport != nil {
|
||||||
|
opsScheduledReport.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"OpsCleanupService", func() error {
|
||||||
|
if opsCleanup != nil {
|
||||||
|
opsCleanup.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"OpsAlertEvaluatorService", func() error {
|
||||||
|
if opsAlertEvaluator != nil {
|
||||||
|
opsAlertEvaluator.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"OpsAggregationService", func() error {
|
||||||
|
if opsAggregation != nil {
|
||||||
|
opsAggregation.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"OpsMetricsCollector", func() error {
|
||||||
|
if opsMetricsCollector != nil {
|
||||||
|
opsMetricsCollector.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
{"SchedulerSnapshotService", func() error {
|
||||||
|
if schedulerSnapshot != nil {
|
||||||
|
schedulerSnapshot.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"TokenRefreshService", func() error {
|
{"TokenRefreshService", func() error {
|
||||||
tokenRefresh.Stop()
|
tokenRefresh.Stop()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -8,9 +8,11 @@ require (
|
|||||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/google/wire v0.7.0
|
github.com/google/wire v0.7.0
|
||||||
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/imroc/req/v3 v3.57.0
|
github.com/imroc/req/v3 v3.57.0
|
||||||
github.com/lib/pq v1.10.9
|
github.com/lib/pq v1.10.9
|
||||||
github.com/redis/go-redis/v9 v9.17.2
|
github.com/redis/go-redis/v9 v9.17.2
|
||||||
|
github.com/shirou/gopsutil/v4 v4.25.6
|
||||||
github.com/spf13/viper v1.18.2
|
github.com/spf13/viper v1.18.2
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
|
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
|
||||||
@@ -44,11 +46,13 @@ require (
|
|||||||
github.com/containerd/platforms v0.2.1 // indirect
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
|
github.com/dgraph-io/ristretto v0.2.0 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/distribution/reference v0.6.0 // indirect
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||||
github.com/docker/go-connections v0.6.0 // indirect
|
github.com/docker/go-connections v0.6.0 // indirect
|
||||||
github.com/docker/go-units v0.5.0 // indirect
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/ebitengine/purego v0.8.4 // indirect
|
github.com/ebitengine/purego v0.8.4 // indirect
|
||||||
github.com/fatih/color v1.18.0 // indirect
|
github.com/fatih/color v1.18.0 // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
@@ -104,9 +108,9 @@ require (
|
|||||||
github.com/quic-go/quic-go v0.57.1 // indirect
|
github.com/quic-go/quic-go v0.57.1 // indirect
|
||||||
github.com/refraction-networking/utls v1.8.1 // indirect
|
github.com/refraction-networking/utls v1.8.1 // indirect
|
||||||
github.com/rivo/uniseg v0.2.0 // indirect
|
github.com/rivo/uniseg v0.2.0 // indirect
|
||||||
|
github.com/robfig/cron/v3 v3.0.1 // indirect
|
||||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
|
||||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||||
|
|||||||
@@ -51,6 +51,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE=
|
||||||
|
github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
@@ -61,6 +63,8 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM
|
|||||||
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
||||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
||||||
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||||
@@ -113,6 +117,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
|||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||||
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
|
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
|
||||||
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
|
||||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||||
@@ -220,6 +226,8 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr
|
|||||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
|
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||||
|
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
|
|||||||
@@ -36,33 +36,29 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `mapstructure:"server"`
|
Server ServerConfig `mapstructure:"server"`
|
||||||
CORS CORSConfig `mapstructure:"cors"`
|
CORS CORSConfig `mapstructure:"cors"`
|
||||||
Security SecurityConfig `mapstructure:"security"`
|
Security SecurityConfig `mapstructure:"security"`
|
||||||
Billing BillingConfig `mapstructure:"billing"`
|
Billing BillingConfig `mapstructure:"billing"`
|
||||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
Redis RedisConfig `mapstructure:"redis"`
|
Redis RedisConfig `mapstructure:"redis"`
|
||||||
JWT JWTConfig `mapstructure:"jwt"`
|
Ops OpsConfig `mapstructure:"ops"`
|
||||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
JWT JWTConfig `mapstructure:"jwt"`
|
||||||
Default DefaultConfig `mapstructure:"default"`
|
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
Default DefaultConfig `mapstructure:"default"`
|
||||||
Pricing PricingConfig `mapstructure:"pricing"`
|
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
Pricing PricingConfig `mapstructure:"pricing"`
|
||||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
|
||||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
|
||||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
|
||||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||||
Update UpdateConfig `mapstructure:"update"`
|
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||||
}
|
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||||
|
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||||
// UpdateConfig 在线更新相关配置
|
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||||
type UpdateConfig struct {
|
Update UpdateConfig `mapstructure:"update"`
|
||||||
// ProxyURL 用于访问 GitHub 的代理地址
|
|
||||||
// 支持 http/https/socks5/socks5h 协议
|
|
||||||
// 例如: "http://127.0.0.1:7890", "socks5://127.0.0.1:1080"
|
|
||||||
ProxyURL string `mapstructure:"proxy_url"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeminiConfig struct {
|
type GeminiConfig struct {
|
||||||
@@ -87,6 +83,33 @@ type GeminiTierQuotaConfig struct {
|
|||||||
CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
|
CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UpdateConfig struct {
|
||||||
|
// ProxyURL 用于访问 GitHub 的代理地址
|
||||||
|
// 支持 http/https/socks5/socks5h 协议
|
||||||
|
// 例如: "http://127.0.0.1:7890", "socks5://127.0.0.1:1080"
|
||||||
|
ProxyURL string `mapstructure:"proxy_url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LinuxDoConnectConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
ClientID string `mapstructure:"client_id"`
|
||||||
|
ClientSecret string `mapstructure:"client_secret"`
|
||||||
|
AuthorizeURL string `mapstructure:"authorize_url"`
|
||||||
|
TokenURL string `mapstructure:"token_url"`
|
||||||
|
UserInfoURL string `mapstructure:"userinfo_url"`
|
||||||
|
Scopes string `mapstructure:"scopes"`
|
||||||
|
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
|
||||||
|
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback)
|
||||||
|
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
|
||||||
|
UsePKCE bool `mapstructure:"use_pkce"`
|
||||||
|
|
||||||
|
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
|
||||||
|
// 为空时,服务端会尝试一组常见字段名。
|
||||||
|
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
|
||||||
|
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
|
||||||
|
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
|
||||||
|
}
|
||||||
|
|
||||||
// TokenRefreshConfig OAuth token自动刷新配置
|
// TokenRefreshConfig OAuth token自动刷新配置
|
||||||
type TokenRefreshConfig struct {
|
type TokenRefreshConfig struct {
|
||||||
// 是否启用自动刷新
|
// 是否启用自动刷新
|
||||||
@@ -247,6 +270,29 @@ type GatewaySchedulingConfig struct {
|
|||||||
|
|
||||||
// 过期槽位清理周期(0 表示禁用)
|
// 过期槽位清理周期(0 表示禁用)
|
||||||
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
|
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
|
||||||
|
|
||||||
|
// 受控回源配置
|
||||||
|
DbFallbackEnabled bool `mapstructure:"db_fallback_enabled"`
|
||||||
|
// 受控回源超时(秒),0 表示不额外收紧超时
|
||||||
|
DbFallbackTimeoutSeconds int `mapstructure:"db_fallback_timeout_seconds"`
|
||||||
|
// 受控回源限流(实例级 QPS),0 表示不限制
|
||||||
|
DbFallbackMaxQPS int `mapstructure:"db_fallback_max_qps"`
|
||||||
|
|
||||||
|
// Outbox 轮询与滞后阈值配置
|
||||||
|
// Outbox 轮询周期(秒)
|
||||||
|
OutboxPollIntervalSeconds int `mapstructure:"outbox_poll_interval_seconds"`
|
||||||
|
// Outbox 滞后告警阈值(秒)
|
||||||
|
OutboxLagWarnSeconds int `mapstructure:"outbox_lag_warn_seconds"`
|
||||||
|
// Outbox 触发强制重建阈值(秒)
|
||||||
|
OutboxLagRebuildSeconds int `mapstructure:"outbox_lag_rebuild_seconds"`
|
||||||
|
// Outbox 连续滞后触发次数
|
||||||
|
OutboxLagRebuildFailures int `mapstructure:"outbox_lag_rebuild_failures"`
|
||||||
|
// Outbox 积压触发重建阈值(行数)
|
||||||
|
OutboxBacklogRebuildRows int `mapstructure:"outbox_backlog_rebuild_rows"`
|
||||||
|
|
||||||
|
// 全量重建周期配置
|
||||||
|
// 全量重建周期(秒),0 表示禁用
|
||||||
|
FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerConfig) Address() string {
|
func (s *ServerConfig) Address() string {
|
||||||
@@ -329,6 +375,47 @@ func (r *RedisConfig) Address() string {
|
|||||||
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OpsConfig struct {
|
||||||
|
// Enabled controls whether ops features should run.
|
||||||
|
//
|
||||||
|
// NOTE: vNext still has a DB-backed feature flag (ops_monitoring_enabled) for runtime on/off.
|
||||||
|
// This config flag is the "hard switch" for deployments that want to disable ops completely.
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
|
||||||
|
// UsePreaggregatedTables prefers ops_metrics_hourly/daily for long-window dashboard queries.
|
||||||
|
UsePreaggregatedTables bool `mapstructure:"use_preaggregated_tables"`
|
||||||
|
|
||||||
|
// Cleanup controls periodic deletion of old ops data to prevent unbounded growth.
|
||||||
|
Cleanup OpsCleanupConfig `mapstructure:"cleanup"`
|
||||||
|
|
||||||
|
// MetricsCollectorCache controls Redis caching for expensive per-window collector queries.
|
||||||
|
MetricsCollectorCache OpsMetricsCollectorCacheConfig `mapstructure:"metrics_collector_cache"`
|
||||||
|
|
||||||
|
// Pre-aggregation configuration.
|
||||||
|
Aggregation OpsAggregationConfig `mapstructure:"aggregation"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpsCleanupConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
Schedule string `mapstructure:"schedule"`
|
||||||
|
|
||||||
|
// Retention days (0 disables that cleanup target).
|
||||||
|
//
|
||||||
|
// vNext requirement: default 30 days across ops datasets.
|
||||||
|
ErrorLogRetentionDays int `mapstructure:"error_log_retention_days"`
|
||||||
|
MinuteMetricsRetentionDays int `mapstructure:"minute_metrics_retention_days"`
|
||||||
|
HourlyMetricsRetentionDays int `mapstructure:"hourly_metrics_retention_days"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpsAggregationConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpsMetricsCollectorCacheConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
TTL time.Duration `mapstructure:"ttl"`
|
||||||
|
}
|
||||||
|
|
||||||
type JWTConfig struct {
|
type JWTConfig struct {
|
||||||
Secret string `mapstructure:"secret"`
|
Secret string `mapstructure:"secret"`
|
||||||
ExpireHour int `mapstructure:"expire_hour"`
|
ExpireHour int `mapstructure:"expire_hour"`
|
||||||
@@ -338,30 +425,6 @@ type TurnstileConfig struct {
|
|||||||
Required bool `mapstructure:"required"`
|
Required bool `mapstructure:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO)。
|
|
||||||
//
|
|
||||||
// 注意:这与上游账号的 OAuth(例如 OpenAI/Gemini 账号接入)不是一回事。
|
|
||||||
// 这里是用于登录 Sub2API 本身的用户体系。
|
|
||||||
type LinuxDoConnectConfig struct {
|
|
||||||
Enabled bool `mapstructure:"enabled"`
|
|
||||||
ClientID string `mapstructure:"client_id"`
|
|
||||||
ClientSecret string `mapstructure:"client_secret"`
|
|
||||||
AuthorizeURL string `mapstructure:"authorize_url"`
|
|
||||||
TokenURL string `mapstructure:"token_url"`
|
|
||||||
UserInfoURL string `mapstructure:"userinfo_url"`
|
|
||||||
Scopes string `mapstructure:"scopes"`
|
|
||||||
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
|
|
||||||
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback)
|
|
||||||
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
|
|
||||||
UsePKCE bool `mapstructure:"use_pkce"`
|
|
||||||
|
|
||||||
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
|
|
||||||
// 为空时,服务端会尝试一组常见字段名。
|
|
||||||
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
|
|
||||||
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
|
|
||||||
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type DefaultConfig struct {
|
type DefaultConfig struct {
|
||||||
AdminEmail string `mapstructure:"admin_email"`
|
AdminEmail string `mapstructure:"admin_email"`
|
||||||
AdminPassword string `mapstructure:"admin_password"`
|
AdminPassword string `mapstructure:"admin_password"`
|
||||||
@@ -375,6 +438,55 @@ type RateLimitConfig struct {
|
|||||||
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthCacheConfig API Key 认证缓存配置
|
||||||
|
type APIKeyAuthCacheConfig struct {
|
||||||
|
L1Size int `mapstructure:"l1_size"`
|
||||||
|
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
|
||||||
|
L2TTLSeconds int `mapstructure:"l2_ttl_seconds"`
|
||||||
|
NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"`
|
||||||
|
JitterPercent int `mapstructure:"jitter_percent"`
|
||||||
|
Singleflight bool `mapstructure:"singleflight"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardCacheConfig 仪表盘统计缓存配置
|
||||||
|
type DashboardCacheConfig struct {
|
||||||
|
// Enabled: 是否启用仪表盘缓存
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
// KeyPrefix: Redis key 前缀,用于多环境隔离
|
||||||
|
KeyPrefix string `mapstructure:"key_prefix"`
|
||||||
|
// StatsFreshTTLSeconds: 缓存命中认为“新鲜”的时间窗口(秒)
|
||||||
|
StatsFreshTTLSeconds int `mapstructure:"stats_fresh_ttl_seconds"`
|
||||||
|
// StatsTTLSeconds: Redis 缓存总 TTL(秒)
|
||||||
|
StatsTTLSeconds int `mapstructure:"stats_ttl_seconds"`
|
||||||
|
// StatsRefreshTimeoutSeconds: 异步刷新超时(秒)
|
||||||
|
StatsRefreshTimeoutSeconds int `mapstructure:"stats_refresh_timeout_seconds"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardAggregationConfig 仪表盘预聚合配置
|
||||||
|
type DashboardAggregationConfig struct {
|
||||||
|
// Enabled: 是否启用预聚合作业
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
// IntervalSeconds: 聚合刷新间隔(秒)
|
||||||
|
IntervalSeconds int `mapstructure:"interval_seconds"`
|
||||||
|
// LookbackSeconds: 回看窗口(秒)
|
||||||
|
LookbackSeconds int `mapstructure:"lookback_seconds"`
|
||||||
|
// BackfillEnabled: 是否允许全量回填
|
||||||
|
BackfillEnabled bool `mapstructure:"backfill_enabled"`
|
||||||
|
// BackfillMaxDays: 回填最大跨度(天)
|
||||||
|
BackfillMaxDays int `mapstructure:"backfill_max_days"`
|
||||||
|
// Retention: 各表保留窗口(天)
|
||||||
|
Retention DashboardAggregationRetentionConfig `mapstructure:"retention"`
|
||||||
|
// RecomputeDays: 启动时重算最近 N 天
|
||||||
|
RecomputeDays int `mapstructure:"recompute_days"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
||||||
|
type DashboardAggregationRetentionConfig struct {
|
||||||
|
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||||
|
HourlyDays int `mapstructure:"hourly_days"`
|
||||||
|
DailyDays int `mapstructure:"daily_days"`
|
||||||
|
}
|
||||||
|
|
||||||
func NormalizeRunMode(value string) string {
|
func NormalizeRunMode(value string) string {
|
||||||
normalized := strings.ToLower(strings.TrimSpace(value))
|
normalized := strings.ToLower(strings.TrimSpace(value))
|
||||||
switch normalized {
|
switch normalized {
|
||||||
@@ -440,6 +552,7 @@ func Load() (*Config, error) {
|
|||||||
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
|
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
|
||||||
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
|
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
|
||||||
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
|
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
|
||||||
|
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
|
||||||
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
||||||
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
||||||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||||||
@@ -478,81 +591,6 @@ func Load() (*Config, error) {
|
|||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL(禁止 fragment)。
|
|
||||||
func ValidateAbsoluteHTTPURL(raw string) error {
|
|
||||||
raw = strings.TrimSpace(raw)
|
|
||||||
if raw == "" {
|
|
||||||
return fmt.Errorf("empty url")
|
|
||||||
}
|
|
||||||
u, err := url.Parse(raw)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !u.IsAbs() {
|
|
||||||
return fmt.Errorf("must be absolute")
|
|
||||||
}
|
|
||||||
if !isHTTPScheme(u.Scheme) {
|
|
||||||
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(u.Host) == "" {
|
|
||||||
return fmt.Errorf("missing host")
|
|
||||||
}
|
|
||||||
if u.Fragment != "" {
|
|
||||||
return fmt.Errorf("must not include fragment")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateFrontendRedirectURL 校验前端回调地址:
|
|
||||||
// - 允许同源相对路径(以 / 开头)
|
|
||||||
// - 或绝对 http(s) URL(禁止 fragment)
|
|
||||||
func ValidateFrontendRedirectURL(raw string) error {
|
|
||||||
raw = strings.TrimSpace(raw)
|
|
||||||
if raw == "" {
|
|
||||||
return fmt.Errorf("empty url")
|
|
||||||
}
|
|
||||||
if strings.ContainsAny(raw, "\r\n") {
|
|
||||||
return fmt.Errorf("contains invalid characters")
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(raw, "/") {
|
|
||||||
if strings.HasPrefix(raw, "//") {
|
|
||||||
return fmt.Errorf("must not start with //")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
u, err := url.Parse(raw)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !u.IsAbs() {
|
|
||||||
return fmt.Errorf("must be absolute http(s) url or relative path")
|
|
||||||
}
|
|
||||||
if !isHTTPScheme(u.Scheme) {
|
|
||||||
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(u.Host) == "" {
|
|
||||||
return fmt.Errorf("missing host")
|
|
||||||
}
|
|
||||||
if u.Fragment != "" {
|
|
||||||
return fmt.Errorf("must not include fragment")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isHTTPScheme(scheme string) bool {
|
|
||||||
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
|
|
||||||
}
|
|
||||||
|
|
||||||
func warnIfInsecureURL(field, raw string) {
|
|
||||||
u, err := url.Parse(strings.TrimSpace(raw))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if strings.EqualFold(u.Scheme, "http") {
|
|
||||||
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func setDefaults() {
|
func setDefaults() {
|
||||||
viper.SetDefault("run_mode", RunModeStandard)
|
viper.SetDefault("run_mode", RunModeStandard)
|
||||||
|
|
||||||
@@ -602,7 +640,7 @@ func setDefaults() {
|
|||||||
// Turnstile
|
// Turnstile
|
||||||
viper.SetDefault("turnstile.required", false)
|
viper.SetDefault("turnstile.required", false)
|
||||||
|
|
||||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
// LinuxDo Connect OAuth 登录
|
||||||
viper.SetDefault("linuxdo_connect.enabled", false)
|
viper.SetDefault("linuxdo_connect.enabled", false)
|
||||||
viper.SetDefault("linuxdo_connect.client_id", "")
|
viper.SetDefault("linuxdo_connect.client_id", "")
|
||||||
viper.SetDefault("linuxdo_connect.client_secret", "")
|
viper.SetDefault("linuxdo_connect.client_secret", "")
|
||||||
@@ -641,6 +679,20 @@ func setDefaults() {
|
|||||||
viper.SetDefault("redis.pool_size", 128)
|
viper.SetDefault("redis.pool_size", 128)
|
||||||
viper.SetDefault("redis.min_idle_conns", 10)
|
viper.SetDefault("redis.min_idle_conns", 10)
|
||||||
|
|
||||||
|
// Ops (vNext)
|
||||||
|
viper.SetDefault("ops.enabled", true)
|
||||||
|
viper.SetDefault("ops.use_preaggregated_tables", false)
|
||||||
|
viper.SetDefault("ops.cleanup.enabled", true)
|
||||||
|
viper.SetDefault("ops.cleanup.schedule", "0 2 * * *")
|
||||||
|
// Retention days: vNext defaults to 30 days across ops datasets.
|
||||||
|
viper.SetDefault("ops.cleanup.error_log_retention_days", 30)
|
||||||
|
viper.SetDefault("ops.cleanup.minute_metrics_retention_days", 30)
|
||||||
|
viper.SetDefault("ops.cleanup.hourly_metrics_retention_days", 30)
|
||||||
|
viper.SetDefault("ops.aggregation.enabled", true)
|
||||||
|
viper.SetDefault("ops.metrics_collector_cache.enabled", true)
|
||||||
|
// TTL should be slightly larger than collection interval (1m) to maximize cross-replica cache hits.
|
||||||
|
viper.SetDefault("ops.metrics_collector_cache.ttl", 65*time.Second)
|
||||||
|
|
||||||
// JWT
|
// JWT
|
||||||
viper.SetDefault("jwt.secret", "")
|
viper.SetDefault("jwt.secret", "")
|
||||||
viper.SetDefault("jwt.expire_hour", 24)
|
viper.SetDefault("jwt.expire_hour", 24)
|
||||||
@@ -669,9 +721,35 @@ func setDefaults() {
|
|||||||
// Timezone (default to Asia/Shanghai for Chinese users)
|
// Timezone (default to Asia/Shanghai for Chinese users)
|
||||||
viper.SetDefault("timezone", "Asia/Shanghai")
|
viper.SetDefault("timezone", "Asia/Shanghai")
|
||||||
|
|
||||||
|
// API Key auth cache
|
||||||
|
viper.SetDefault("api_key_auth_cache.l1_size", 65535)
|
||||||
|
viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15)
|
||||||
|
viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300)
|
||||||
|
viper.SetDefault("api_key_auth_cache.negative_ttl_seconds", 30)
|
||||||
|
viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
|
||||||
|
viper.SetDefault("api_key_auth_cache.singleflight", true)
|
||||||
|
|
||||||
|
// Dashboard cache
|
||||||
|
viper.SetDefault("dashboard_cache.enabled", true)
|
||||||
|
viper.SetDefault("dashboard_cache.key_prefix", "sub2api:")
|
||||||
|
viper.SetDefault("dashboard_cache.stats_fresh_ttl_seconds", 15)
|
||||||
|
viper.SetDefault("dashboard_cache.stats_ttl_seconds", 30)
|
||||||
|
viper.SetDefault("dashboard_cache.stats_refresh_timeout_seconds", 30)
|
||||||
|
|
||||||
|
// Dashboard aggregation
|
||||||
|
viper.SetDefault("dashboard_aggregation.enabled", true)
|
||||||
|
viper.SetDefault("dashboard_aggregation.interval_seconds", 60)
|
||||||
|
viper.SetDefault("dashboard_aggregation.lookback_seconds", 120)
|
||||||
|
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
||||||
|
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
||||||
|
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
||||||
|
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
||||||
|
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||||
|
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||||
|
|
||||||
// Gateway
|
// Gateway
|
||||||
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
|
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||||
viper.SetDefault("gateway.log_upstream_error_body", false)
|
viper.SetDefault("gateway.log_upstream_error_body", true)
|
||||||
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
|
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
|
||||||
viper.SetDefault("gateway.inject_beta_for_apikey", false)
|
viper.SetDefault("gateway.inject_beta_for_apikey", false)
|
||||||
viper.SetDefault("gateway.failover_on_400", false)
|
viper.SetDefault("gateway.failover_on_400", false)
|
||||||
@@ -687,13 +765,22 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
|
||||||
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
|
||||||
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
viper.SetDefault("gateway.stream_keepalive_interval", 10)
|
||||||
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
|
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
|
||||||
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
|
||||||
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
|
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
|
||||||
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
|
||||||
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
|
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
|
||||||
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
|
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
|
||||||
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
|
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
|
||||||
|
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
|
||||||
|
viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0)
|
||||||
|
viper.SetDefault("gateway.scheduling.db_fallback_max_qps", 0)
|
||||||
|
viper.SetDefault("gateway.scheduling.outbox_poll_interval_seconds", 1)
|
||||||
|
viper.SetDefault("gateway.scheduling.outbox_lag_warn_seconds", 5)
|
||||||
|
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_seconds", 10)
|
||||||
|
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
|
||||||
|
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
|
||||||
|
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
|
||||||
viper.SetDefault("concurrency.ping_interval", 10)
|
viper.SetDefault("concurrency.ping_interval", 10)
|
||||||
|
|
||||||
// TokenRefresh
|
// TokenRefresh
|
||||||
@@ -710,10 +797,6 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gemini.oauth.client_secret", "")
|
viper.SetDefault("gemini.oauth.client_secret", "")
|
||||||
viper.SetDefault("gemini.oauth.scopes", "")
|
viper.SetDefault("gemini.oauth.scopes", "")
|
||||||
viper.SetDefault("gemini.quota.policy", "")
|
viper.SetDefault("gemini.quota.policy", "")
|
||||||
|
|
||||||
// Update - 在线更新配置
|
|
||||||
// 代理地址为空表示直连 GitHub(适用于海外服务器)
|
|
||||||
viper.SetDefault("update.proxy_url", "")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) Validate() error {
|
func (c *Config) Validate() error {
|
||||||
@@ -754,7 +837,8 @@ func (c *Config) Validate() error {
|
|||||||
if method == "none" && !c.LinuxDo.UsePKCE {
|
if method == "none" && !c.LinuxDo.UsePKCE {
|
||||||
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
|
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
|
||||||
}
|
}
|
||||||
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
|
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
|
||||||
|
strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
|
||||||
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
|
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" {
|
if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" {
|
||||||
@@ -827,6 +911,78 @@ func (c *Config) Validate() error {
|
|||||||
if c.Redis.MinIdleConns > c.Redis.PoolSize {
|
if c.Redis.MinIdleConns > c.Redis.PoolSize {
|
||||||
return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size")
|
return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size")
|
||||||
}
|
}
|
||||||
|
if c.Dashboard.Enabled {
|
||||||
|
if c.Dashboard.StatsFreshTTLSeconds <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Dashboard.StatsTTLSeconds <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Dashboard.StatsRefreshTimeoutSeconds <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Dashboard.StatsFreshTTLSeconds > c.Dashboard.StatsTTLSeconds {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be <= dashboard_cache.stats_ttl_seconds")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if c.Dashboard.StatsFreshTTLSeconds < 0 {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Dashboard.StatsTTLSeconds < 0 {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Dashboard.StatsRefreshTimeoutSeconds < 0 {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Enabled {
|
||||||
|
if c.DashboardAgg.IntervalSeconds <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.interval_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.LookbackSeconds < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.BackfillMaxDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.BackfillEnabled && c.DashboardAgg.BackfillMaxDays == 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.backfill_max_days must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.DailyDays <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.daily_days must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.RecomputeDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if c.DashboardAgg.IntervalSeconds < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.interval_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.LookbackSeconds < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.BackfillMaxDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.DailyDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.daily_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.RecomputeDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
|
||||||
|
}
|
||||||
|
}
|
||||||
if c.Gateway.MaxBodySize <= 0 {
|
if c.Gateway.MaxBodySize <= 0 {
|
||||||
return fmt.Errorf("gateway.max_body_size must be positive")
|
return fmt.Errorf("gateway.max_body_size must be positive")
|
||||||
}
|
}
|
||||||
@@ -897,6 +1053,50 @@ func (c *Config) Validate() error {
|
|||||||
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
|
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
|
||||||
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
|
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
|
||||||
}
|
}
|
||||||
|
if c.Gateway.Scheduling.DbFallbackTimeoutSeconds < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.db_fallback_timeout_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.DbFallbackMaxQPS < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.db_fallback_max_qps must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.OutboxPollIntervalSeconds <= 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.outbox_poll_interval_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.OutboxLagWarnSeconds < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.outbox_lag_warn_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.OutboxLagRebuildSeconds < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.OutboxLagRebuildFailures <= 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_failures must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.OutboxBacklogRebuildRows < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.outbox_backlog_rebuild_rows must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.FullRebuildIntervalSeconds < 0 {
|
||||||
|
return fmt.Errorf("gateway.scheduling.full_rebuild_interval_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.Scheduling.OutboxLagWarnSeconds > 0 &&
|
||||||
|
c.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 &&
|
||||||
|
c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds {
|
||||||
|
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds")
|
||||||
|
}
|
||||||
|
if c.Ops.MetricsCollectorCache.TTL < 0 {
|
||||||
|
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Ops.Cleanup.ErrorLogRetentionDays < 0 {
|
||||||
|
return fmt.Errorf("ops.cleanup.error_log_retention_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Ops.Cleanup.MinuteMetricsRetentionDays < 0 {
|
||||||
|
return fmt.Errorf("ops.cleanup.minute_metrics_retention_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Ops.Cleanup.HourlyMetricsRetentionDays < 0 {
|
||||||
|
return fmt.Errorf("ops.cleanup.hourly_metrics_retention_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Ops.Cleanup.Enabled && strings.TrimSpace(c.Ops.Cleanup.Schedule) == "" {
|
||||||
|
return fmt.Errorf("ops.cleanup.schedule is required when ops.cleanup.enabled=true")
|
||||||
|
}
|
||||||
if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 {
|
if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 {
|
||||||
return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
|
return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
|
||||||
}
|
}
|
||||||
@@ -973,3 +1173,77 @@ func GetServerAddress() string {
|
|||||||
port := v.GetInt("server.port")
|
port := v.GetInt("server.port")
|
||||||
return fmt.Sprintf("%s:%d", host, port)
|
return fmt.Sprintf("%s:%d", host, port)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateAbsoluteHTTPURL 验证是否为有效的绝对 HTTP(S) URL
|
||||||
|
func ValidateAbsoluteHTTPURL(raw string) error {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return fmt.Errorf("empty url")
|
||||||
|
}
|
||||||
|
u, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !u.IsAbs() {
|
||||||
|
return fmt.Errorf("must be absolute")
|
||||||
|
}
|
||||||
|
if !isHTTPScheme(u.Scheme) {
|
||||||
|
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(u.Host) == "" {
|
||||||
|
return fmt.Errorf("missing host")
|
||||||
|
}
|
||||||
|
if u.Fragment != "" {
|
||||||
|
return fmt.Errorf("must not include fragment")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateFrontendRedirectURL 验证前端重定向 URL(可以是绝对 URL 或相对路径)
|
||||||
|
func ValidateFrontendRedirectURL(raw string) error {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return fmt.Errorf("empty url")
|
||||||
|
}
|
||||||
|
if strings.ContainsAny(raw, "\r\n") {
|
||||||
|
return fmt.Errorf("contains invalid characters")
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(raw, "/") {
|
||||||
|
if strings.HasPrefix(raw, "//") {
|
||||||
|
return fmt.Errorf("must not start with //")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
u, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !u.IsAbs() {
|
||||||
|
return fmt.Errorf("must be absolute http(s) url or relative path")
|
||||||
|
}
|
||||||
|
if !isHTTPScheme(u.Scheme) {
|
||||||
|
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(u.Host) == "" {
|
||||||
|
return fmt.Errorf("missing host")
|
||||||
|
}
|
||||||
|
if u.Fragment != "" {
|
||||||
|
return fmt.Errorf("must not include fragment")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isHTTPScheme 检查是否为 HTTP 或 HTTPS 协议
|
||||||
|
func isHTTPScheme(scheme string) bool {
|
||||||
|
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
|
||||||
|
}
|
||||||
|
|
||||||
|
func warnIfInsecureURL(field, raw string) {
|
||||||
|
u, err := url.Parse(strings.TrimSpace(raw))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.EqualFold(u.Scheme, "http") {
|
||||||
|
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -39,8 +39,8 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
|||||||
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
|
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
|
||||||
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
|
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
|
||||||
}
|
}
|
||||||
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second {
|
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 120*time.Second {
|
||||||
t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
|
t.Fatalf("StickySessionWaitTimeout = %v, want 120s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
|
||||||
}
|
}
|
||||||
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
|
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
|
||||||
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
|
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
|
||||||
@@ -141,3 +141,142 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
|
|||||||
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
|
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.Dashboard.Enabled {
|
||||||
|
t.Fatalf("Dashboard.Enabled = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.KeyPrefix != "sub2api:" {
|
||||||
|
t.Fatalf("Dashboard.KeyPrefix = %q, want %q", cfg.Dashboard.KeyPrefix, "sub2api:")
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.StatsFreshTTLSeconds != 15 {
|
||||||
|
t.Fatalf("Dashboard.StatsFreshTTLSeconds = %d, want 15", cfg.Dashboard.StatsFreshTTLSeconds)
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.StatsTTLSeconds != 30 {
|
||||||
|
t.Fatalf("Dashboard.StatsTTLSeconds = %d, want 30", cfg.Dashboard.StatsTTLSeconds)
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.StatsRefreshTimeoutSeconds != 30 {
|
||||||
|
t.Fatalf("Dashboard.StatsRefreshTimeoutSeconds = %d, want 30", cfg.Dashboard.StatsRefreshTimeoutSeconds)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Dashboard.Enabled = true
|
||||||
|
cfg.Dashboard.StatsFreshTTLSeconds = 10
|
||||||
|
cfg.Dashboard.StatsTTLSeconds = 5
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Validate() expected error for stats_fresh_ttl_seconds > stats_ttl_seconds, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "dashboard_cache.stats_fresh_ttl_seconds") {
|
||||||
|
t.Fatalf("Validate() expected stats_fresh_ttl_seconds error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Dashboard.Enabled = false
|
||||||
|
cfg.Dashboard.StatsTTLSeconds = -1
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Validate() expected error for negative stats_ttl_seconds, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "dashboard_cache.stats_ttl_seconds") {
|
||||||
|
t.Fatalf("Validate() expected stats_ttl_seconds error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.DashboardAgg.Enabled {
|
||||||
|
t.Fatalf("DashboardAgg.Enabled = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.IntervalSeconds != 60 {
|
||||||
|
t.Fatalf("DashboardAgg.IntervalSeconds = %d, want 60", cfg.DashboardAgg.IntervalSeconds)
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.LookbackSeconds != 120 {
|
||||||
|
t.Fatalf("DashboardAgg.LookbackSeconds = %d, want 120", cfg.DashboardAgg.LookbackSeconds)
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.BackfillEnabled {
|
||||||
|
t.Fatalf("DashboardAgg.BackfillEnabled = true, want false")
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.BackfillMaxDays != 31 {
|
||||||
|
t.Fatalf("DashboardAgg.BackfillMaxDays = %d, want 31", cfg.DashboardAgg.BackfillMaxDays)
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
||||||
|
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
||||||
|
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.Retention.DailyDays != 730 {
|
||||||
|
t.Fatalf("DashboardAgg.Retention.DailyDays = %d, want 730", cfg.DashboardAgg.Retention.DailyDays)
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.RecomputeDays != 2 {
|
||||||
|
t.Fatalf("DashboardAgg.RecomputeDays = %d, want 2", cfg.DashboardAgg.RecomputeDays)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.DashboardAgg.Enabled = false
|
||||||
|
cfg.DashboardAgg.IntervalSeconds = -1
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Validate() expected error for negative dashboard_aggregation.interval_seconds, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "dashboard_aggregation.interval_seconds") {
|
||||||
|
t.Fatalf("Validate() expected interval_seconds error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.DashboardAgg.BackfillEnabled = true
|
||||||
|
cfg.DashboardAgg.BackfillMaxDays = 0
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Validate() expected error for dashboard_aggregation.backfill_max_days, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "dashboard_aggregation.backfill_max_days") {
|
||||||
|
t.Fatalf("Validate() expected backfill_max_days error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -13,15 +14,17 @@ import (
|
|||||||
|
|
||||||
// DashboardHandler handles admin dashboard statistics
|
// DashboardHandler handles admin dashboard statistics
|
||||||
type DashboardHandler struct {
|
type DashboardHandler struct {
|
||||||
dashboardService *service.DashboardService
|
dashboardService *service.DashboardService
|
||||||
startTime time.Time // Server start time for uptime calculation
|
aggregationService *service.DashboardAggregationService
|
||||||
|
startTime time.Time // Server start time for uptime calculation
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDashboardHandler creates a new admin dashboard handler
|
// NewDashboardHandler creates a new admin dashboard handler
|
||||||
func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler {
|
func NewDashboardHandler(dashboardService *service.DashboardService, aggregationService *service.DashboardAggregationService) *DashboardHandler {
|
||||||
return &DashboardHandler{
|
return &DashboardHandler{
|
||||||
dashboardService: dashboardService,
|
dashboardService: dashboardService,
|
||||||
startTime: time.Now(),
|
aggregationService: aggregationService,
|
||||||
|
startTime: time.Now(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,6 +117,58 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
|
|||||||
// 性能指标
|
// 性能指标
|
||||||
"rpm": stats.Rpm,
|
"rpm": stats.Rpm,
|
||||||
"tpm": stats.Tpm,
|
"tpm": stats.Tpm,
|
||||||
|
|
||||||
|
// 预聚合新鲜度
|
||||||
|
"hourly_active_users": stats.HourlyActiveUsers,
|
||||||
|
"stats_updated_at": stats.StatsUpdatedAt,
|
||||||
|
"stats_stale": stats.StatsStale,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type DashboardAggregationBackfillRequest struct {
|
||||||
|
Start string `json:"start"`
|
||||||
|
End string `json:"end"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackfillAggregation handles triggering aggregation backfill
|
||||||
|
// POST /api/v1/admin/dashboard/aggregation/backfill
|
||||||
|
func (h *DashboardHandler) BackfillAggregation(c *gin.Context) {
|
||||||
|
if h.aggregationService == nil {
|
||||||
|
response.InternalError(c, "Aggregation service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req DashboardAggregationBackfillRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
start, err := time.Parse(time.RFC3339, req.Start)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid start time")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
end, err := time.Parse(time.RFC3339, req.End)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid end time")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.aggregationService.TriggerBackfill(start, end); err != nil {
|
||||||
|
if errors.Is(err, service.ErrDashboardBackfillDisabled) {
|
||||||
|
response.Forbidden(c, "Backfill is disabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errors.Is(err, service.ErrDashboardBackfillTooLarge) {
|
||||||
|
response.BadRequest(c, "Backfill range too large")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalError(c, "Failed to trigger backfill")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"status": "accepted",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
432
backend/internal/handler/admin/ops_alerts_handler.go
Normal file
432
backend/internal/handler/admin/ops_alerts_handler.go
Normal file
@@ -0,0 +1,432 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gin-gonic/gin/binding"
|
||||||
|
)
|
||||||
|
|
||||||
|
var validOpsAlertMetricTypes = []string{
|
||||||
|
"success_rate",
|
||||||
|
"error_rate",
|
||||||
|
"upstream_error_rate",
|
||||||
|
"p95_latency_ms",
|
||||||
|
"p99_latency_ms",
|
||||||
|
"cpu_usage_percent",
|
||||||
|
"memory_usage_percent",
|
||||||
|
"concurrency_queue_depth",
|
||||||
|
}
|
||||||
|
|
||||||
|
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
|
||||||
|
set := make(map[string]struct{}, len(validOpsAlertMetricTypes))
|
||||||
|
for _, v := range validOpsAlertMetricTypes {
|
||||||
|
set[v] = struct{}{}
|
||||||
|
}
|
||||||
|
return set
|
||||||
|
}()
|
||||||
|
|
||||||
|
var validOpsAlertOperators = []string{">", "<", ">=", "<=", "==", "!="}
|
||||||
|
|
||||||
|
var validOpsAlertOperatorSet = func() map[string]struct{} {
|
||||||
|
set := make(map[string]struct{}, len(validOpsAlertOperators))
|
||||||
|
for _, v := range validOpsAlertOperators {
|
||||||
|
set[v] = struct{}{}
|
||||||
|
}
|
||||||
|
return set
|
||||||
|
}()
|
||||||
|
|
||||||
|
var validOpsAlertSeverities = []string{"P0", "P1", "P2", "P3"}
|
||||||
|
|
||||||
|
var validOpsAlertSeveritySet = func() map[string]struct{} {
|
||||||
|
set := make(map[string]struct{}, len(validOpsAlertSeverities))
|
||||||
|
for _, v := range validOpsAlertSeverities {
|
||||||
|
set[v] = struct{}{}
|
||||||
|
}
|
||||||
|
return set
|
||||||
|
}()
|
||||||
|
|
||||||
|
type opsAlertRuleValidatedInput struct {
|
||||||
|
Name string
|
||||||
|
MetricType string
|
||||||
|
Operator string
|
||||||
|
Threshold float64
|
||||||
|
|
||||||
|
Severity string
|
||||||
|
|
||||||
|
WindowMinutes int
|
||||||
|
SustainedMinutes int
|
||||||
|
CooldownMinutes int
|
||||||
|
|
||||||
|
Enabled bool
|
||||||
|
NotifyEmail bool
|
||||||
|
|
||||||
|
WindowProvided bool
|
||||||
|
SustainedProvided bool
|
||||||
|
CooldownProvided bool
|
||||||
|
SeverityProvided bool
|
||||||
|
EnabledProvided bool
|
||||||
|
NotifyProvided bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func isPercentOrRateMetric(metricType string) bool {
|
||||||
|
switch metricType {
|
||||||
|
case "success_rate",
|
||||||
|
"error_rate",
|
||||||
|
"upstream_error_rate",
|
||||||
|
"cpu_usage_percent",
|
||||||
|
"memory_usage_percent":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateOpsAlertRulePayload(raw map[string]json.RawMessage) (*opsAlertRuleValidatedInput, error) {
|
||||||
|
if raw == nil {
|
||||||
|
return nil, fmt.Errorf("invalid request body")
|
||||||
|
}
|
||||||
|
|
||||||
|
requiredFields := []string{"name", "metric_type", "operator", "threshold"}
|
||||||
|
for _, field := range requiredFields {
|
||||||
|
if _, ok := raw[field]; !ok {
|
||||||
|
return nil, fmt.Errorf("%s is required", field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var name string
|
||||||
|
if err := json.Unmarshal(raw["name"], &name); err != nil || strings.TrimSpace(name) == "" {
|
||||||
|
return nil, fmt.Errorf("name is required")
|
||||||
|
}
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
|
||||||
|
var metricType string
|
||||||
|
if err := json.Unmarshal(raw["metric_type"], &metricType); err != nil || strings.TrimSpace(metricType) == "" {
|
||||||
|
return nil, fmt.Errorf("metric_type is required")
|
||||||
|
}
|
||||||
|
metricType = strings.TrimSpace(metricType)
|
||||||
|
if _, ok := validOpsAlertMetricTypeSet[metricType]; !ok {
|
||||||
|
return nil, fmt.Errorf("metric_type must be one of: %s", strings.Join(validOpsAlertMetricTypes, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
var operator string
|
||||||
|
if err := json.Unmarshal(raw["operator"], &operator); err != nil || strings.TrimSpace(operator) == "" {
|
||||||
|
return nil, fmt.Errorf("operator is required")
|
||||||
|
}
|
||||||
|
operator = strings.TrimSpace(operator)
|
||||||
|
if _, ok := validOpsAlertOperatorSet[operator]; !ok {
|
||||||
|
return nil, fmt.Errorf("operator must be one of: %s", strings.Join(validOpsAlertOperators, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
var threshold float64
|
||||||
|
if err := json.Unmarshal(raw["threshold"], &threshold); err != nil {
|
||||||
|
return nil, fmt.Errorf("threshold must be a number")
|
||||||
|
}
|
||||||
|
if math.IsNaN(threshold) || math.IsInf(threshold, 0) {
|
||||||
|
return nil, fmt.Errorf("threshold must be a finite number")
|
||||||
|
}
|
||||||
|
if isPercentOrRateMetric(metricType) {
|
||||||
|
if threshold < 0 || threshold > 100 {
|
||||||
|
return nil, fmt.Errorf("threshold must be between 0 and 100 for metric_type %s", metricType)
|
||||||
|
}
|
||||||
|
} else if threshold < 0 {
|
||||||
|
return nil, fmt.Errorf("threshold must be >= 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
validated := &opsAlertRuleValidatedInput{
|
||||||
|
Name: name,
|
||||||
|
MetricType: metricType,
|
||||||
|
Operator: operator,
|
||||||
|
Threshold: threshold,
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := raw["severity"]; ok {
|
||||||
|
validated.SeverityProvided = true
|
||||||
|
var sev string
|
||||||
|
if err := json.Unmarshal(v, &sev); err != nil {
|
||||||
|
return nil, fmt.Errorf("severity must be a string")
|
||||||
|
}
|
||||||
|
sev = strings.ToUpper(strings.TrimSpace(sev))
|
||||||
|
if sev != "" {
|
||||||
|
if _, ok := validOpsAlertSeveritySet[sev]; !ok {
|
||||||
|
return nil, fmt.Errorf("severity must be one of: %s", strings.Join(validOpsAlertSeverities, ", "))
|
||||||
|
}
|
||||||
|
validated.Severity = sev
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if validated.Severity == "" {
|
||||||
|
validated.Severity = "P2"
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := raw["enabled"]; ok {
|
||||||
|
validated.EnabledProvided = true
|
||||||
|
if err := json.Unmarshal(v, &validated.Enabled); err != nil {
|
||||||
|
return nil, fmt.Errorf("enabled must be a boolean")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
validated.Enabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := raw["notify_email"]; ok {
|
||||||
|
validated.NotifyProvided = true
|
||||||
|
if err := json.Unmarshal(v, &validated.NotifyEmail); err != nil {
|
||||||
|
return nil, fmt.Errorf("notify_email must be a boolean")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
validated.NotifyEmail = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := raw["window_minutes"]; ok {
|
||||||
|
validated.WindowProvided = true
|
||||||
|
if err := json.Unmarshal(v, &validated.WindowMinutes); err != nil {
|
||||||
|
return nil, fmt.Errorf("window_minutes must be an integer")
|
||||||
|
}
|
||||||
|
switch validated.WindowMinutes {
|
||||||
|
case 1, 5, 60:
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("window_minutes must be one of: 1, 5, 60")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
validated.WindowMinutes = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := raw["sustained_minutes"]; ok {
|
||||||
|
validated.SustainedProvided = true
|
||||||
|
if err := json.Unmarshal(v, &validated.SustainedMinutes); err != nil {
|
||||||
|
return nil, fmt.Errorf("sustained_minutes must be an integer")
|
||||||
|
}
|
||||||
|
if validated.SustainedMinutes < 1 || validated.SustainedMinutes > 1440 {
|
||||||
|
return nil, fmt.Errorf("sustained_minutes must be between 1 and 1440")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
validated.SustainedMinutes = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := raw["cooldown_minutes"]; ok {
|
||||||
|
validated.CooldownProvided = true
|
||||||
|
if err := json.Unmarshal(v, &validated.CooldownMinutes); err != nil {
|
||||||
|
return nil, fmt.Errorf("cooldown_minutes must be an integer")
|
||||||
|
}
|
||||||
|
if validated.CooldownMinutes < 0 || validated.CooldownMinutes > 1440 {
|
||||||
|
return nil, fmt.Errorf("cooldown_minutes must be between 0 and 1440")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
validated.CooldownMinutes = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return validated, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAlertRules returns all ops alert rules.
|
||||||
|
// GET /api/v1/admin/ops/alert-rules
|
||||||
|
func (h *OpsHandler) ListAlertRules(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := h.opsService.ListAlertRules(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, rules)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAlertRule creates an ops alert rule.
// POST /api/v1/admin/ops/alert-rules
//
// The body is decoded twice via ShouldBindBodyWith (which caches the raw
// body between calls): first into a raw JSON map so the validator can tell
// missing fields apart from zero values, then into the typed service struct.
// Validated/normalized values then overwrite the bound fields so only
// whitelisted data reaches the service layer.
func (h *OpsHandler) CreateAlertRule(c *gin.Context) {
	if h.opsService == nil {
		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
		return
	}
	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	// First pass: raw map for strict payload validation.
	var raw map[string]json.RawMessage
	if err := c.ShouldBindBodyWith(&raw, binding.JSON); err != nil {
		response.BadRequest(c, "Invalid request body")
		return
	}
	validated, err := validateOpsAlertRulePayload(raw)
	if err != nil {
		response.BadRequest(c, err.Error())
		return
	}

	// Second pass: bind the same cached body into the typed rule.
	var rule service.OpsAlertRule
	if err := c.ShouldBindBodyWith(&rule, binding.JSON); err != nil {
		response.BadRequest(c, "Invalid request body")
		return
	}

	// Overwrite bound fields with the validated values.
	rule.Name = validated.Name
	rule.MetricType = validated.MetricType
	rule.Operator = validated.Operator
	rule.Threshold = validated.Threshold
	rule.WindowMinutes = validated.WindowMinutes
	rule.SustainedMinutes = validated.SustainedMinutes
	rule.CooldownMinutes = validated.CooldownMinutes
	rule.Severity = validated.Severity
	rule.Enabled = validated.Enabled
	rule.NotifyEmail = validated.NotifyEmail

	created, err := h.opsService.CreateAlertRule(c.Request.Context(), &rule)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	response.Success(c, created)
}
|
||||||
|
|
||||||
|
// UpdateAlertRule updates an existing ops alert rule.
// PUT /api/v1/admin/ops/alert-rules/:id
//
// Mirrors CreateAlertRule: the cached body is bound twice (raw map for
// validation, then the typed struct), and validated values overwrite the
// bound fields.
func (h *OpsHandler) UpdateAlertRule(c *gin.Context) {
	if h.opsService == nil {
		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
		return
	}
	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	// Path param must be a positive integer rule ID.
	id, err := strconv.ParseInt(c.Param("id"), 10, 64)
	if err != nil || id <= 0 {
		response.BadRequest(c, "Invalid rule ID")
		return
	}

	// First pass: raw map for strict payload validation.
	var raw map[string]json.RawMessage
	if err := c.ShouldBindBodyWith(&raw, binding.JSON); err != nil {
		response.BadRequest(c, "Invalid request body")
		return
	}
	validated, err := validateOpsAlertRulePayload(raw)
	if err != nil {
		response.BadRequest(c, err.Error())
		return
	}

	// Second pass: bind the same cached body into the typed rule.
	var rule service.OpsAlertRule
	if err := c.ShouldBindBodyWith(&rule, binding.JSON); err != nil {
		response.BadRequest(c, "Invalid request body")
		return
	}

	// The rule ID always comes from the path, never from the body.
	rule.ID = id
	rule.Name = validated.Name
	rule.MetricType = validated.MetricType
	rule.Operator = validated.Operator
	rule.Threshold = validated.Threshold
	rule.WindowMinutes = validated.WindowMinutes
	rule.SustainedMinutes = validated.SustainedMinutes
	rule.CooldownMinutes = validated.CooldownMinutes
	rule.Severity = validated.Severity
	rule.Enabled = validated.Enabled
	rule.NotifyEmail = validated.NotifyEmail

	updated, err := h.opsService.UpdateAlertRule(c.Request.Context(), &rule)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	response.Success(c, updated)
}
|
||||||
|
|
||||||
|
// DeleteAlertRule deletes an ops alert rule.
|
||||||
|
// DELETE /api/v1/admin/ops/alert-rules/:id
|
||||||
|
func (h *OpsHandler) DeleteAlertRule(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid rule ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.opsService.DeleteAlertRule(c.Request.Context(), id); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"deleted": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAlertEvents lists recent ops alert events.
|
||||||
|
// GET /api/v1/admin/ops/alert-events
|
||||||
|
func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := 100
|
||||||
|
if raw := strings.TrimSpace(c.Query("limit")); raw != "" {
|
||||||
|
n, err := strconv.Atoi(raw)
|
||||||
|
if err != nil || n <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid limit")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
limit = n
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := &service.OpsAlertEventFilter{
|
||||||
|
Limit: limit,
|
||||||
|
Status: strings.TrimSpace(c.Query("status")),
|
||||||
|
Severity: strings.TrimSpace(c.Query("severity")),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional global filter support (platform/group/time range).
|
||||||
|
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
|
||||||
|
filter.Platform = platform
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||||
|
id, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid group_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filter.GroupID = &id
|
||||||
|
}
|
||||||
|
if startTime, endTime, err := parseOpsTimeRange(c, "24h"); err == nil {
|
||||||
|
// Only apply when explicitly provided to avoid surprising default narrowing.
|
||||||
|
if strings.TrimSpace(c.Query("start_time")) != "" || strings.TrimSpace(c.Query("end_time")) != "" || strings.TrimSpace(c.Query("time_range")) != "" {
|
||||||
|
filter.StartTime = &startTime
|
||||||
|
filter.EndTime = &endTime
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := h.opsService.ListAlertEvents(c.Request.Context(), filter)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, events)
|
||||||
|
}
|
||||||
243
backend/internal/handler/admin/ops_dashboard_handler.go
Normal file
243
backend/internal/handler/admin/ops_dashboard_handler.go
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetDashboardOverview returns vNext ops dashboard overview (raw path).
// GET /api/v1/admin/ops/dashboard/overview
//
// Optional query params: time_range/start_time/end_time (default 1h),
// platform, group_id, mode.
func (h *OpsHandler) GetDashboardOverview(c *gin.Context) {
	if h.opsService == nil {
		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
		return
	}
	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	startTime, endTime, err := parseOpsTimeRange(c, "1h")
	if err != nil {
		response.BadRequest(c, err.Error())
		return
	}

	// NOTE(review): this filter-building sequence is repeated across the
	// dashboard handlers; a shared helper would reduce drift.
	filter := &service.OpsDashboardFilter{
		StartTime: startTime,
		EndTime:   endTime,
		Platform:  strings.TrimSpace(c.Query("platform")),
		QueryMode: parseOpsQueryMode(c),
	}
	if v := strings.TrimSpace(c.Query("group_id")); v != "" {
		id, err := strconv.ParseInt(v, 10, 64)
		if err != nil || id <= 0 {
			response.BadRequest(c, "Invalid group_id")
			return
		}
		filter.GroupID = &id
	}

	data, err := h.opsService.GetDashboardOverview(c.Request.Context(), filter)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	response.Success(c, data)
}
|
||||||
|
|
||||||
|
// GetDashboardThroughputTrend returns throughput time series (raw path).
|
||||||
|
// GET /api/v1/admin/ops/dashboard/throughput-trend
|
||||||
|
func (h *OpsHandler) GetDashboardThroughputTrend(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime, endTime, err := parseOpsTimeRange(c, "1h")
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := &service.OpsDashboardFilter{
|
||||||
|
StartTime: startTime,
|
||||||
|
EndTime: endTime,
|
||||||
|
Platform: strings.TrimSpace(c.Query("platform")),
|
||||||
|
QueryMode: parseOpsQueryMode(c),
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||||
|
id, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid group_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filter.GroupID = &id
|
||||||
|
}
|
||||||
|
|
||||||
|
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
|
||||||
|
data, err := h.opsService.GetThroughputTrend(c.Request.Context(), filter, bucketSeconds)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDashboardLatencyHistogram returns the latency distribution histogram (success requests).
// GET /api/v1/admin/ops/dashboard/latency-histogram
//
// Optional query params: time_range/start_time/end_time (default 1h),
// platform, group_id, mode.
func (h *OpsHandler) GetDashboardLatencyHistogram(c *gin.Context) {
	if h.opsService == nil {
		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
		return
	}
	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	startTime, endTime, err := parseOpsTimeRange(c, "1h")
	if err != nil {
		response.BadRequest(c, err.Error())
		return
	}

	filter := &service.OpsDashboardFilter{
		StartTime: startTime,
		EndTime:   endTime,
		Platform:  strings.TrimSpace(c.Query("platform")),
		QueryMode: parseOpsQueryMode(c),
	}
	// group_id, when present, must be a positive integer.
	if v := strings.TrimSpace(c.Query("group_id")); v != "" {
		id, err := strconv.ParseInt(v, 10, 64)
		if err != nil || id <= 0 {
			response.BadRequest(c, "Invalid group_id")
			return
		}
		filter.GroupID = &id
	}

	data, err := h.opsService.GetLatencyHistogram(c.Request.Context(), filter)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	response.Success(c, data)
}
|
||||||
|
|
||||||
|
// GetDashboardErrorTrend returns error counts time series (raw path).
|
||||||
|
// GET /api/v1/admin/ops/dashboard/error-trend
|
||||||
|
func (h *OpsHandler) GetDashboardErrorTrend(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime, endTime, err := parseOpsTimeRange(c, "1h")
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := &service.OpsDashboardFilter{
|
||||||
|
StartTime: startTime,
|
||||||
|
EndTime: endTime,
|
||||||
|
Platform: strings.TrimSpace(c.Query("platform")),
|
||||||
|
QueryMode: parseOpsQueryMode(c),
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||||
|
id, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid group_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filter.GroupID = &id
|
||||||
|
}
|
||||||
|
|
||||||
|
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
|
||||||
|
data, err := h.opsService.GetErrorTrend(c.Request.Context(), filter, bucketSeconds)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDashboardErrorDistribution returns error distribution by status code (raw path).
// GET /api/v1/admin/ops/dashboard/error-distribution
//
// Optional query params: time_range/start_time/end_time (default 1h),
// platform, group_id, mode.
func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) {
	if h.opsService == nil {
		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
		return
	}
	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	startTime, endTime, err := parseOpsTimeRange(c, "1h")
	if err != nil {
		response.BadRequest(c, err.Error())
		return
	}

	filter := &service.OpsDashboardFilter{
		StartTime: startTime,
		EndTime:   endTime,
		Platform:  strings.TrimSpace(c.Query("platform")),
		QueryMode: parseOpsQueryMode(c),
	}
	// group_id, when present, must be a positive integer.
	if v := strings.TrimSpace(c.Query("group_id")); v != "" {
		id, err := strconv.ParseInt(v, 10, 64)
		if err != nil || id <= 0 {
			response.BadRequest(c, "Invalid group_id")
			return
		}
		filter.GroupID = &id
	}

	data, err := h.opsService.GetErrorDistribution(c.Request.Context(), filter)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	response.Success(c, data)
}
|
||||||
|
|
||||||
|
// pickThroughputBucketSeconds chooses the time-series bucket width (in
// seconds) for a query window. Wider windows get coarser buckets so
// responses stay predictable in size: <=2h -> 1min, <=24h -> 5min,
// otherwise 1h.
func pickThroughputBucketSeconds(window time.Duration) int {
	if window <= 2*time.Hour {
		return 60
	}
	if window <= 24*time.Hour {
		return 300
	}
	return 3600
}
|
||||||
|
|
||||||
|
func parseOpsQueryMode(c *gin.Context) service.OpsQueryMode {
|
||||||
|
if c == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
raw := strings.TrimSpace(c.Query("mode"))
|
||||||
|
if raw == "" {
|
||||||
|
// Empty means "use server default" (DB setting ops_query_mode_default).
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return service.ParseOpsQueryMode(raw)
|
||||||
|
}
|
||||||
364
backend/internal/handler/admin/ops_handler.go
Normal file
364
backend/internal/handler/admin/ops_handler.go
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpsHandler serves the admin ops/observability endpoints: error logs,
// request drill-down, retry, alert rules/events, dashboards and realtime
// stats.
type OpsHandler struct {
	// opsService may be nil when ops monitoring is not configured; every
	// handler checks for nil and answers 503 in that case.
	opsService *service.OpsService
}
|
||||||
|
|
||||||
|
// NewOpsHandler constructs an OpsHandler backed by the given ops service.
// A nil opsService is tolerated; the handlers then respond with 503.
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
	return &OpsHandler{opsService: opsService}
}
|
||||||
|
|
||||||
|
// GetErrorLogs lists ops error logs.
// GET /api/v1/admin/ops/errors
//
// Supports pagination (page/page_size, capped at 500), a time window
// (default 1h) and optional filters: platform, group_id, account_id,
// phase, free-text q, and a comma-separated status_codes list.
func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
	if h.opsService == nil {
		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
		return
	}
	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	page, pageSize := response.ParsePagination(c)
	// Ops list can be larger than standard admin tables.
	if pageSize > 500 {
		pageSize = 500
	}

	startTime, endTime, err := parseOpsTimeRange(c, "1h")
	if err != nil {
		response.BadRequest(c, err.Error())
		return
	}

	filter := &service.OpsErrorLogFilter{
		Page:     page,
		PageSize: pageSize,
	}
	// NOTE(review): parseOpsTimeRange appears to always return non-zero
	// times on success, so these guards look purely defensive — confirm.
	if !startTime.IsZero() {
		filter.StartTime = &startTime
	}
	if !endTime.IsZero() {
		filter.EndTime = &endTime
	}

	if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
		filter.Platform = platform
	}
	if v := strings.TrimSpace(c.Query("group_id")); v != "" {
		id, err := strconv.ParseInt(v, 10, 64)
		if err != nil || id <= 0 {
			response.BadRequest(c, "Invalid group_id")
			return
		}
		filter.GroupID = &id
	}
	if v := strings.TrimSpace(c.Query("account_id")); v != "" {
		id, err := strconv.ParseInt(v, 10, 64)
		if err != nil || id <= 0 {
			response.BadRequest(c, "Invalid account_id")
			return
		}
		filter.AccountID = &id
	}
	if phase := strings.TrimSpace(c.Query("phase")); phase != "" {
		filter.Phase = phase
	}
	if q := strings.TrimSpace(c.Query("q")); q != "" {
		filter.Query = q
	}
	// status_codes: comma-separated non-negative integers; blanks skipped.
	if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
		parts := strings.Split(statusCodesStr, ",")
		out := make([]int, 0, len(parts))
		for _, part := range parts {
			p := strings.TrimSpace(part)
			if p == "" {
				continue
			}
			n, err := strconv.Atoi(p)
			if err != nil || n < 0 {
				response.BadRequest(c, "Invalid status_codes")
				return
			}
			out = append(out, n)
		}
		filter.StatusCodes = out
	}

	result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
|
||||||
|
|
||||||
|
// GetErrorLogByID returns a single error log detail.
|
||||||
|
// GET /api/v1/admin/ops/errors/:id
|
||||||
|
func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idStr := strings.TrimSpace(c.Param("id"))
|
||||||
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||||
|
if err != nil || id <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid error id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListRequestDetails returns a request-level list (success + error) for drill-down.
// GET /api/v1/admin/ops/requests
//
// Supports pagination (capped at 100), a time window (default 1h), string
// filters (kind/platform/model/request_id/q/sort), positive-integer ID
// filters (user_id/api_key_id/account_id/group_id) and duration bounds
// (min_duration_ms/max_duration_ms).
func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
	if h.opsService == nil {
		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
		return
	}
	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	page, pageSize := response.ParsePagination(c)
	if pageSize > 100 {
		pageSize = 100
	}

	startTime, endTime, err := parseOpsTimeRange(c, "1h")
	if err != nil {
		response.BadRequest(c, err.Error())
		return
	}

	filter := &service.OpsRequestDetailFilter{
		Page:      page,
		PageSize:  pageSize,
		StartTime: &startTime,
		EndTime:   &endTime,
	}

	// String filters are passed through as-is; the service layer validates
	// sort/kind/platform values.
	filter.Kind = strings.TrimSpace(c.Query("kind"))
	filter.Platform = strings.TrimSpace(c.Query("platform"))
	filter.Model = strings.TrimSpace(c.Query("model"))
	filter.RequestID = strings.TrimSpace(c.Query("request_id"))
	filter.Query = strings.TrimSpace(c.Query("q"))
	filter.Sort = strings.TrimSpace(c.Query("sort"))

	if v := strings.TrimSpace(c.Query("user_id")); v != "" {
		id, err := strconv.ParseInt(v, 10, 64)
		if err != nil || id <= 0 {
			response.BadRequest(c, "Invalid user_id")
			return
		}
		filter.UserID = &id
	}
	if v := strings.TrimSpace(c.Query("api_key_id")); v != "" {
		id, err := strconv.ParseInt(v, 10, 64)
		if err != nil || id <= 0 {
			response.BadRequest(c, "Invalid api_key_id")
			return
		}
		filter.APIKeyID = &id
	}
	if v := strings.TrimSpace(c.Query("account_id")); v != "" {
		id, err := strconv.ParseInt(v, 10, 64)
		if err != nil || id <= 0 {
			response.BadRequest(c, "Invalid account_id")
			return
		}
		filter.AccountID = &id
	}
	if v := strings.TrimSpace(c.Query("group_id")); v != "" {
		id, err := strconv.ParseInt(v, 10, 64)
		if err != nil || id <= 0 {
			response.BadRequest(c, "Invalid group_id")
			return
		}
		filter.GroupID = &id
	}

	// Duration bounds must be non-negative integers (milliseconds).
	// NOTE(review): min > max is not rejected here; presumably the service
	// just returns an empty page — confirm.
	if v := strings.TrimSpace(c.Query("min_duration_ms")); v != "" {
		parsed, err := strconv.Atoi(v)
		if err != nil || parsed < 0 {
			response.BadRequest(c, "Invalid min_duration_ms")
			return
		}
		filter.MinDurationMs = &parsed
	}
	if v := strings.TrimSpace(c.Query("max_duration_ms")); v != "" {
		parsed, err := strconv.Atoi(v)
		if err != nil || parsed < 0 {
			response.BadRequest(c, "Invalid max_duration_ms")
			return
		}
		filter.MaxDurationMs = &parsed
	}

	out, err := h.opsService.ListRequestDetails(c.Request.Context(), filter)
	if err != nil {
		// Invalid sort/kind/platform etc should be a bad request; keep it simple.
		if strings.Contains(strings.ToLower(err.Error()), "invalid") {
			response.BadRequest(c, err.Error())
			return
		}
		response.Error(c, http.StatusInternalServerError, "Failed to list request details")
		return
	}

	response.Paginated(c, out.Items, out.Total, out.Page, out.PageSize)
}
|
||||||
|
|
||||||
|
// opsRetryRequest is the optional JSON body for RetryErrorRequest.
type opsRetryRequest struct {
	// Mode selects the retry strategy; empty defaults to
	// service.OpsRetryModeClient.
	Mode string `json:"mode"`
	// PinnedAccountID, when set, is forwarded to the service —
	// presumably to pin the retry to a specific account; confirm in
	// OpsService.RetryError.
	PinnedAccountID *int64 `json:"pinned_account_id"`
}
|
||||||
|
|
||||||
|
// RetryErrorRequest retries a failed request using stored request_body.
// POST /api/v1/admin/ops/errors/:id/retry
//
// Requires an authenticated admin subject. The JSON body is optional: an
// empty body (io.EOF from binding) is accepted and falls back to the
// client retry mode.
func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
	if h.opsService == nil {
		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
		return
	}
	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	// The acting admin's user ID is attributed to the retry.
	subject, ok := middleware.GetAuthSubjectFromContext(c)
	if !ok || subject.UserID <= 0 {
		response.Error(c, http.StatusUnauthorized, "Unauthorized")
		return
	}

	idStr := strings.TrimSpace(c.Param("id"))
	id, err := strconv.ParseInt(idStr, 10, 64)
	if err != nil || id <= 0 {
		response.BadRequest(c, "Invalid error id")
		return
	}

	// Tolerate a missing body: only a malformed (non-empty) body is an error.
	req := opsRetryRequest{Mode: service.OpsRetryModeClient}
	if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
		response.BadRequest(c, "Invalid request: "+err.Error())
		return
	}
	// A body that omits mode (or sets it blank) also gets the default.
	if strings.TrimSpace(req.Mode) == "" {
		req.Mode = service.OpsRetryModeClient
	}

	result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	response.Success(c, result)
}
|
||||||
|
|
||||||
|
// parseOpsTimeRange resolves the [start, end] window for ops queries.
//
// Precedence: explicit start_time/end_time query params (RFC3339Nano or
// RFC3339) win; otherwise the relative time_range token (e.g. "1h") is
// used, falling back to defaultRange when absent or unrecognized. The
// window is capped at 30 days and start must not be after end.
func parseOpsTimeRange(c *gin.Context, defaultRange string) (time.Time, time.Time, error) {
	startStr := strings.TrimSpace(c.Query("start_time"))
	endStr := strings.TrimSpace(c.Query("end_time"))

	// parseTS accepts RFC3339Nano first, then RFC3339; "" means "unset"
	// (zero time, no error).
	parseTS := func(s string) (time.Time, error) {
		if s == "" {
			return time.Time{}, nil
		}
		if t, err := time.Parse(time.RFC3339Nano, s); err == nil {
			return t, nil
		}
		return time.Parse(time.RFC3339, s)
	}

	start, err := parseTS(startStr)
	if err != nil {
		return time.Time{}, time.Time{}, err
	}
	end, err := parseTS(endStr)
	if err != nil {
		return time.Time{}, time.Time{}, err
	}

	// start/end explicitly provided (even partially)
	if startStr != "" || endStr != "" {
		// Missing end defaults to "now"; missing start backs off from end
		// by the default range.
		if end.IsZero() {
			end = time.Now()
		}
		if start.IsZero() {
			dur, _ := parseOpsDuration(defaultRange)
			start = end.Add(-dur)
		}
		if start.After(end) {
			return time.Time{}, time.Time{}, fmt.Errorf("invalid time range: start_time must be <= end_time")
		}
		if end.Sub(start) > 30*24*time.Hour {
			return time.Time{}, time.Time{}, fmt.Errorf("invalid time range: max window is 30 days")
		}
		return start, end, nil
	}

	// time_range fallback
	tr := strings.TrimSpace(c.Query("time_range"))
	if tr == "" {
		tr = defaultRange
	}
	dur, ok := parseOpsDuration(tr)
	if !ok {
		// Unrecognized token: silently fall back to the default range.
		dur, _ = parseOpsDuration(defaultRange)
	}

	end = time.Now()
	start = end.Add(-dur)
	// NOTE(review): with parseOpsDuration's current table (max 24h) this
	// guard cannot fire; kept defensively for larger future tokens.
	if end.Sub(start) > 30*24*time.Hour {
		return time.Time{}, time.Time{}, fmt.Errorf("invalid time range: max window is 30 days")
	}
	return start, end, nil
}
|
||||||
|
|
||||||
|
// parseOpsDuration maps a relative time-range token to a duration. The
// boolean reports whether the token was recognized; unknown tokens
// (including "") yield (0, false).
func parseOpsDuration(v string) (time.Duration, bool) {
	token := strings.TrimSpace(v)
	switch token {
	case "5m":
		return 5 * time.Minute, true
	case "30m":
		return 30 * time.Minute, true
	case "1h":
		return time.Hour, true
	case "6h":
		return 6 * time.Hour, true
	case "24h":
		return 24 * time.Hour, true
	}
	return 0, false
}
|
||||||
213
backend/internal/handler/admin/ops_realtime_handler.go
Normal file
213
backend/internal/handler/admin/ops_realtime_handler.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetConcurrencyStats returns real-time concurrency usage aggregated by platform/group/account.
// GET /api/v1/admin/ops/concurrency
//
// Optional query params: platform, group_id.
func (h *OpsHandler) GetConcurrencyStats(c *gin.Context) {
	if h.opsService == nil {
		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
		return
	}
	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	// Realtime collection disabled: answer with an explicit empty payload
	// so clients can distinguish "feature off" from "no data".
	if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
		response.Success(c, gin.H{
			"enabled":   false,
			"platform":  map[string]*service.PlatformConcurrencyInfo{},
			"group":     map[int64]*service.GroupConcurrencyInfo{},
			"account":   map[int64]*service.AccountConcurrencyInfo{},
			"timestamp": time.Now().UTC(),
		})
		return
	}

	platformFilter := strings.TrimSpace(c.Query("platform"))
	var groupID *int64
	if v := strings.TrimSpace(c.Query("group_id")); v != "" {
		id, err := strconv.ParseInt(v, 10, 64)
		if err != nil || id <= 0 {
			response.BadRequest(c, "Invalid group_id")
			return
		}
		groupID = &id
	}

	platform, group, account, collectedAt, err := h.opsService.GetConcurrencyStats(c.Request.Context(), platformFilter, groupID)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	payload := gin.H{
		"enabled":  true,
		"platform": platform,
		"group":    group,
		"account":  account,
	}
	// timestamp is present only when the collector reported a sample time.
	if collectedAt != nil {
		payload["timestamp"] = collectedAt.UTC()
	}
	response.Success(c, payload)
}
|
||||||
|
|
||||||
|
// GetAccountAvailability returns account availability statistics.
// GET /api/v1/admin/ops/account-availability
//
// Query params:
// - platform: optional
// - group_id: optional
func (h *OpsHandler) GetAccountAvailability(c *gin.Context) {
	if h.opsService == nil {
		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
		return
	}
	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	// Realtime collection disabled: answer with an explicit empty payload
	// so clients can distinguish "feature off" from "no data".
	if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
		response.Success(c, gin.H{
			"enabled":   false,
			"platform":  map[string]*service.PlatformAvailability{},
			"group":     map[int64]*service.GroupAvailability{},
			"account":   map[int64]*service.AccountAvailability{},
			"timestamp": time.Now().UTC(),
		})
		return
	}

	platform := strings.TrimSpace(c.Query("platform"))
	var groupID *int64
	if v := strings.TrimSpace(c.Query("group_id")); v != "" {
		id, err := strconv.ParseInt(v, 10, 64)
		if err != nil || id <= 0 {
			response.BadRequest(c, "Invalid group_id")
			return
		}
		groupID = &id
	}

	platformStats, groupStats, accountStats, collectedAt, err := h.opsService.GetAccountAvailabilityStats(c.Request.Context(), platform, groupID)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	payload := gin.H{
		"enabled":  true,
		"platform": platformStats,
		"group":    groupStats,
		"account":  accountStats,
	}
	// timestamp is present only when the collector reported a sample time.
	if collectedAt != nil {
		payload["timestamp"] = collectedAt.UTC()
	}
	response.Success(c, payload)
}
|
||||||
|
|
||||||
|
func parseOpsRealtimeWindow(v string) (time.Duration, string, bool) {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||||
|
case "", "1min", "1m":
|
||||||
|
return 1 * time.Minute, "1min", true
|
||||||
|
case "5min", "5m":
|
||||||
|
return 5 * time.Minute, "5min", true
|
||||||
|
case "30min", "30m":
|
||||||
|
return 30 * time.Minute, "30min", true
|
||||||
|
case "1h", "60m", "60min":
|
||||||
|
return 1 * time.Hour, "1h", true
|
||||||
|
default:
|
||||||
|
return 0, "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRealtimeTrafficSummary returns QPS/TPS current/peak/avg for the selected window.
// GET /api/v1/admin/ops/realtime-traffic
//
// Query params:
//   - window: 1min|5min|30min|1h (default: 1min)
//   - platform: optional
//   - group_id: optional
//
// When realtime monitoring is disabled, a zero-valued summary for the
// requested window is returned with enabled=false rather than an error.
func (h *OpsHandler) GetRealtimeTrafficSummary(c *gin.Context) {
	if h.opsService == nil {
		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
		return
	}
	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	// Validate the window first; everything below derives from it.
	windowDur, windowLabel, ok := parseOpsRealtimeWindow(c.Query("window"))
	if !ok {
		response.BadRequest(c, "Invalid window")
		return
	}

	platform := strings.TrimSpace(c.Query("platform"))
	// Optional positive group filter.
	var groupID *int64
	if v := strings.TrimSpace(c.Query("group_id")); v != "" {
		id, err := strconv.ParseInt(v, 10, 64)
		if err != nil || id <= 0 {
			response.BadRequest(c, "Invalid group_id")
			return
		}
		groupID = &id
	}

	// The window always ends "now" and looks back windowDur.
	endTime := time.Now().UTC()
	startTime := endTime.Add(-windowDur)

	// Disabled: return an all-zero summary that still echoes the requested
	// window/filters so clients can render a consistent shape.
	if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
		disabledSummary := &service.OpsRealtimeTrafficSummary{
			Window:    windowLabel,
			StartTime: startTime,
			EndTime:   endTime,
			Platform:  platform,
			GroupID:   groupID,
			QPS:       service.OpsRateSummary{},
			TPS:       service.OpsRateSummary{},
		}
		response.Success(c, gin.H{
			"enabled":   false,
			"summary":   disabledSummary,
			"timestamp": endTime,
		})
		return
	}

	// Raw query mode: realtime summaries are computed from raw samples.
	filter := &service.OpsDashboardFilter{
		StartTime: startTime,
		EndTime:   endTime,
		Platform:  platform,
		GroupID:   groupID,
		QueryMode: service.OpsQueryModeRaw,
	}

	summary, err := h.opsService.GetRealtimeTrafficSummary(c.Request.Context(), filter)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	// Stamp the canonical label so the client sees "1min" even if it sent "1m".
	if summary != nil {
		summary.Window = windowLabel
	}
	response.Success(c, gin.H{
		"enabled":   true,
		"summary":   summary,
		"timestamp": endTime,
	})
}
|
||||||
194
backend/internal/handler/admin/ops_settings_handler.go
Normal file
194
backend/internal/handler/admin/ops_settings_handler.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetEmailNotificationConfig returns Ops email notification config (DB-backed).
|
||||||
|
// GET /api/v1/admin/ops/email-notification/config
|
||||||
|
func (h *OpsHandler) GetEmailNotificationConfig(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := h.opsService.GetEmailNotificationConfig(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusInternalServerError, "Failed to get email notification config")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateEmailNotificationConfig updates Ops email notification config (DB-backed).
|
||||||
|
// PUT /api/v1/admin/ops/email-notification/config
|
||||||
|
func (h *OpsHandler) UpdateEmailNotificationConfig(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req service.OpsEmailNotificationConfigUpdateRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := h.opsService.UpdateEmailNotificationConfig(c.Request.Context(), &req)
|
||||||
|
if err != nil {
|
||||||
|
// Most failures here are validation errors from request payload; treat as 400.
|
||||||
|
response.Error(c, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAlertRuntimeSettings returns Ops alert evaluator runtime settings (DB-backed).
|
||||||
|
// GET /api/v1/admin/ops/runtime/alert
|
||||||
|
func (h *OpsHandler) GetAlertRuntimeSettings(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := h.opsService.GetOpsAlertRuntimeSettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusInternalServerError, "Failed to get alert runtime settings")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAlertRuntimeSettings updates Ops alert evaluator runtime settings (DB-backed).
|
||||||
|
// PUT /api/v1/admin/ops/runtime/alert
|
||||||
|
func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req service.OpsAlertRuntimeSettings
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := h.opsService.UpdateOpsAlertRuntimeSettings(c.Request.Context(), &req)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAdvancedSettings returns Ops advanced settings (DB-backed).
|
||||||
|
// GET /api/v1/admin/ops/advanced-settings
|
||||||
|
func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := h.opsService.GetOpsAdvancedSettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusInternalServerError, "Failed to get advanced settings")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAdvancedSettings updates Ops advanced settings (DB-backed).
|
||||||
|
// PUT /api/v1/admin/ops/advanced-settings
|
||||||
|
func (h *OpsHandler) UpdateAdvancedSettings(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req service.OpsAdvancedSettings
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := h.opsService.UpdateOpsAdvancedSettings(c.Request.Context(), &req)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetricThresholds returns Ops metric thresholds (DB-backed).
|
||||||
|
// GET /api/v1/admin/ops/settings/metric-thresholds
|
||||||
|
func (h *OpsHandler) GetMetricThresholds(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := h.opsService.GetMetricThresholds(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusInternalServerError, "Failed to get metric thresholds")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMetricThresholds updates Ops metric thresholds (DB-backed).
|
||||||
|
// PUT /api/v1/admin/ops/settings/metric-thresholds
|
||||||
|
func (h *OpsHandler) UpdateMetricThresholds(c *gin.Context) {
|
||||||
|
if h.opsService == nil {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req service.OpsMetricThresholds
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := h.opsService.UpdateMetricThresholds(c.Request.Context(), &req)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, updated)
|
||||||
|
}
|
||||||
771
backend/internal/handler/admin/ops_ws_handler.go
Normal file
771
backend/internal/handler/admin/ops_ws_handler.go
Normal file
@@ -0,0 +1,771 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"log"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpsWSProxyConfig controls how the Ops WebSocket endpoint treats reverse
// proxies and Origin headers. It is loaded once from environment variables
// at process start (see loadOpsWSProxyConfigFromEnv).
type OpsWSProxyConfig struct {
	// TrustProxy enables use of X-Forwarded-* headers, but only when the
	// direct peer is inside TrustedProxies.
	TrustProxy bool
	// TrustedProxies lists CIDR prefixes of proxies whose forwarded headers
	// may be believed.
	TrustedProxies []netip.Prefix
	// OriginPolicy is "strict" or "permissive" and governs how requests
	// without an Origin header are handled (see isAllowedOpsWSOrigin).
	OriginPolicy string
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
envOpsWSTrustProxy = "OPS_WS_TRUST_PROXY"
|
||||||
|
envOpsWSTrustedProxies = "OPS_WS_TRUSTED_PROXIES"
|
||||||
|
envOpsWSOriginPolicy = "OPS_WS_ORIGIN_POLICY"
|
||||||
|
envOpsWSMaxConns = "OPS_WS_MAX_CONNS"
|
||||||
|
envOpsWSMaxConnsPerIP = "OPS_WS_MAX_CONNS_PER_IP"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
OriginPolicyStrict = "strict"
|
||||||
|
OriginPolicyPermissive = "permissive"
|
||||||
|
)
|
||||||
|
|
||||||
|
var opsWSProxyConfig = loadOpsWSProxyConfigFromEnv()
|
||||||
|
|
||||||
|
// upgrader performs the HTTP→WebSocket handshake for the Ops QPS endpoint.
// Origin checking is delegated to isAllowedOpsWSOrigin (same-host policy with
// optional proxy-header trust).
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		return isAllowedOpsWSOrigin(r)
	},
	// Subprotocol negotiation:
	// - The frontend passes ["sub2api-admin", "jwt.<token>"].
	// - We always select "sub2api-admin" so the token is never echoed back in the handshake response.
	Subprotocols: []string{"sub2api-admin"},
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
qpsWSPushInterval = 2 * time.Second
|
||||||
|
qpsWSRefreshInterval = 5 * time.Second
|
||||||
|
qpsWSRequestCountWindow = 1 * time.Minute
|
||||||
|
|
||||||
|
defaultMaxWSConns = 100
|
||||||
|
defaultMaxWSConnsPerIP = 20
|
||||||
|
)
|
||||||
|
|
||||||
|
var wsConnCount atomic.Int32
|
||||||
|
var wsConnCountByIP sync.Map // map[string]*atomic.Int32
|
||||||
|
|
||||||
|
const qpsWSIdleStopDelay = 30 * time.Second
|
||||||
|
|
||||||
|
const (
|
||||||
|
opsWSCloseRealtimeDisabled = 4001
|
||||||
|
)
|
||||||
|
|
||||||
|
var qpsWSIdleStopMu sync.Mutex
|
||||||
|
var qpsWSIdleStopTimer *time.Timer
|
||||||
|
|
||||||
|
func cancelQPSWSIdleStop() {
|
||||||
|
qpsWSIdleStopMu.Lock()
|
||||||
|
if qpsWSIdleStopTimer != nil {
|
||||||
|
qpsWSIdleStopTimer.Stop()
|
||||||
|
qpsWSIdleStopTimer = nil
|
||||||
|
}
|
||||||
|
qpsWSIdleStopMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// scheduleQPSWSIdleStop arms a one-shot timer that stops the shared QPS cache
// after qpsWSIdleStopDelay, provided no WebSocket client is connected when the
// timer fires. If a timer is already pending, this is a no-op (no rescheduling).
func scheduleQPSWSIdleStop() {
	qpsWSIdleStopMu.Lock()
	if qpsWSIdleStopTimer != nil {
		// A shutdown is already scheduled; keep its original deadline.
		qpsWSIdleStopMu.Unlock()
		return
	}
	qpsWSIdleStopTimer = time.AfterFunc(qpsWSIdleStopDelay, func() {
		// Only stop if truly idle at fire time.
		if wsConnCount.Load() == 0 {
			qpsWSCache.Stop()
		}
		// Clear the timer slot so a later disconnect can schedule again.
		qpsWSIdleStopMu.Lock()
		qpsWSIdleStopTimer = nil
		qpsWSIdleStopMu.Unlock()
	})
	qpsWSIdleStopMu.Unlock()
}
|
||||||
|
|
||||||
|
// opsWSRuntimeLimits caps the number of concurrent Ops WebSocket connections.
// A non-positive value disables the corresponding limit (see
// tryAcquireOpsWSTotalSlot / tryAcquireOpsWSIPSlot).
type opsWSRuntimeLimits struct {
	MaxConns      int32 // global cap across all clients
	MaxConnsPerIP int32 // per-client-IP cap
}

// opsWSLimits is resolved once from OPS_WS_MAX_CONNS / OPS_WS_MAX_CONNS_PER_IP.
var opsWSLimits = loadOpsWSRuntimeLimitsFromEnv()

const (
	qpsWSWriteTimeout = 10 * time.Second // deadline for each outbound frame
	qpsWSPongWait     = 60 * time.Second // read deadline, refreshed on pong
	qpsWSPingInterval = 30 * time.Second // heartbeat cadence (< qpsWSPongWait)

	// We don't expect clients to send application messages; we only read to process control frames (Pong/Close).
	qpsWSMaxReadBytes = 1024
)
|
||||||
|
|
||||||
|
// opsWSQPSCache periodically computes a QPS/TPS JSON payload and caches it so
// every connected WebSocket client reads the same pre-marshaled message
// instead of hitting the ops service individually.
type opsWSQPSCache struct {
	refreshInterval    time.Duration // how often the background loop recomputes
	requestCountWindow time.Duration // lookback window for rate computation

	lastUpdatedUnixNano atomic.Int64 // UnixNano of the last successful refresh
	payload             atomic.Value // []byte — latest marshaled qps_update message

	opsService *service.OpsService // set by start, cleared by Stop
	cancel     context.CancelFunc  // cancels the refresh loop's context
	done       chan struct{}       // closed when the refresh goroutine exits

	// mu guards running and the lifecycle fields above (opsService/cancel/done).
	mu      sync.Mutex
	running bool
}
|
||||||
|
|
||||||
|
// qpsWSCache is the process-wide QPS payload cache shared by all WS clients.
// Its refresh loop is started lazily on the first connection and stopped after
// an idle delay once the last client disconnects (see scheduleQPSWSIdleStop).
var qpsWSCache = &opsWSQPSCache{
	refreshInterval:    qpsWSRefreshInterval,
	requestCountWindow: qpsWSRequestCountWindow,
}
|
||||||
|
|
||||||
|
// start lazily launches the background refresh loop. It is safe to call
// concurrently and repeatedly: at most one loop runs at a time. If a previous
// loop is still draining (running=false but done!=nil), the caller blocks
// until it fully exits and then retries, so loops never overlap.
func (c *opsWSQPSCache) start(opsService *service.OpsService) {
	if c == nil || opsService == nil {
		return
	}

	for {
		c.mu.Lock()
		if c.running {
			// Already active; nothing to do.
			c.mu.Unlock()
			return
		}

		// If a previous refresh loop is currently stopping, wait for it to fully exit.
		done := c.done
		if done != nil {
			c.mu.Unlock()
			<-done

			// Clear the stale done channel only if nobody restarted meanwhile,
			// then retry the whole start sequence.
			c.mu.Lock()
			if c.done == done && !c.running {
				c.done = nil
			}
			c.mu.Unlock()
			continue
		}

		// Fresh start: install the service, a cancelable context, and a done
		// channel that the goroutine closes on exit.
		c.opsService = opsService
		ctx, cancel := context.WithCancel(context.Background())
		c.cancel = cancel
		c.done = make(chan struct{})
		done = c.done
		c.running = true
		c.mu.Unlock()

		go func() {
			defer close(done)
			c.refreshLoop(ctx)
		}()
		return
	}
}
|
||||||
|
|
||||||
|
// Stop stops the background refresh loop.
// It is safe to call multiple times and from multiple goroutines; every caller
// returns only after the refresh goroutine has fully exited.
func (c *opsWSQPSCache) Stop() {
	if c == nil {
		return
	}

	c.mu.Lock()
	if !c.running {
		// Not running: if another Stop is mid-shutdown, wait for its goroutine
		// to finish so Stop's "fully stopped" guarantee holds for all callers.
		done := c.done
		c.mu.Unlock()
		if done != nil {
			<-done
		}
		return
	}
	// Claim the shutdown: flip state under the lock, then cancel/wait outside it.
	cancel := c.cancel
	c.cancel = nil
	c.running = false
	c.opsService = nil
	done := c.done
	c.mu.Unlock()

	if cancel != nil {
		cancel()
	}
	if done != nil {
		<-done
	}

	// Clear the done channel unless a concurrent start already replaced it.
	c.mu.Lock()
	if c.done == done && !c.running {
		c.done = nil
	}
	c.mu.Unlock()
}
|
||||||
|
|
||||||
|
func (c *opsWSQPSCache) refreshLoop(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(c.refreshInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
c.refresh(ctx)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
c.refresh(ctx)
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// refresh recomputes the QPS/TPS payload from the last requestCountWindow of
// ops window stats and atomically publishes the marshaled JSON for broadcast.
// On any failure the previously published payload is left untouched.
func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
	if c == nil {
		return
	}

	// Snapshot the service under the lock; Stop may nil it out concurrently.
	c.mu.Lock()
	opsService := c.opsService
	c.mu.Unlock()
	if opsService == nil {
		return
	}

	if parentCtx == nil {
		parentCtx = context.Background()
	}
	// Bound each refresh so a slow backend can't wedge the loop.
	ctx, cancel := context.WithTimeout(parentCtx, 10*time.Second)
	defer cancel()

	now := time.Now().UTC()
	stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now)
	if err != nil || stats == nil {
		if err != nil {
			log.Printf("[OpsWS] refresh: get window stats failed: %v", err)
		}
		return
	}

	// Request rate counts successes and errors; token rate uses consumed tokens.
	requestCount := stats.SuccessCount + stats.ErrorCountTotal
	qps := 0.0
	tps := 0.0
	if c.requestCountWindow > 0 {
		seconds := c.requestCountWindow.Seconds()
		qps = roundTo1DP(float64(requestCount) / seconds)
		tps = roundTo1DP(float64(stats.TokenConsumed) / seconds)
	}

	payload := gin.H{
		"type":      "qps_update",
		"timestamp": now.Format(time.RFC3339),
		"data": gin.H{
			"qps":           qps,
			"tps":           tps,
			"request_count": requestCount,
		},
	}

	msg, err := json.Marshal(payload)
	if err != nil {
		log.Printf("[OpsWS] refresh: marshal payload failed: %v", err)
		return
	}

	// Publish atomically so concurrent WS writers always see a complete message.
	c.payload.Store(msg)
	c.lastUpdatedUnixNano.Store(now.UnixNano())
}
|
||||||
|
|
||||||
|
// roundTo1DP rounds v to one decimal place, halves away from zero.
func roundTo1DP(v float64) float64 {
	const scale = 10
	return math.Round(v*scale) / scale
}
|
||||||
|
|
||||||
|
func (c *opsWSQPSCache) getPayload() []byte {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cached, ok := c.payload.Load().([]byte); ok && cached != nil {
|
||||||
|
return cached
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeWS(conn *websocket.Conn, code int, reason string) {
|
||||||
|
if conn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg := websocket.FormatCloseMessage(code, reason)
|
||||||
|
_ = conn.WriteControl(websocket.CloseMessage, msg, time.Now().Add(qpsWSWriteTimeout))
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// QPSWSHandler handles realtime QPS push via WebSocket.
// GET /api/v1/admin/ops/ws/qps
//
// Connection lifecycle: global and per-IP slots are reserved BEFORE the
// upgrade so limits are strict; the deferred releases run after the socket
// loop exits, and the last disconnect schedules a delayed cache shutdown.
func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
	clientIP := requestClientIP(c.Request)

	if h == nil || h.opsService == nil {
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "ops service not initialized"})
		return
	}

	// If realtime monitoring is disabled, prefer a successful WS upgrade followed by a clean close
	// with a deterministic close code. This prevents clients from spinning on 404/1006 reconnect loops.
	if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
		conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
		if err != nil {
			// Upgrade failed (e.g. not a WS request): fall back to a plain 404 JSON.
			c.JSON(http.StatusNotFound, gin.H{"error": "ops realtime monitoring is disabled"})
			return
		}
		closeWS(conn, opsWSCloseRealtimeDisabled, "realtime_disabled")
		return
	}

	cancelQPSWSIdleStop()
	// Lazily start the background refresh loop so unit tests that never hit the
	// websocket route don't spawn goroutines that depend on DB/Redis stubs.
	qpsWSCache.start(h.opsService)

	// Reserve a global slot before upgrading the connection to keep the limit strict.
	if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
		log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
		return
	}
	defer func() {
		// Release the global slot; the last connection out schedules the
		// idle shutdown of the shared refresh loop.
		if wsConnCount.Add(-1) == 0 {
			scheduleQPSWSIdleStop()
		}
	}()

	// Per-IP limit only applies when enabled and the client IP is known.
	if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
		if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
			log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
			c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
			return
		}
		defer releaseOpsWSIPSlot(clientIP)
	}

	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Upgrade writes its own HTTP error response; just log here.
		log.Printf("[OpsWS] upgrade failed: %v", err)
		return
	}

	defer func() {
		_ = conn.Close()
	}()

	handleQPSWebSocket(c.Request.Context(), conn)
}
|
||||||
|
|
||||||
|
// tryAcquireOpsWSTotalSlot attempts to reserve one global connection slot via
// a CAS loop on wsConnCount. A non-positive limit means "unlimited" and always
// succeeds. Returns false without side effects when the limit is reached.
func tryAcquireOpsWSTotalSlot(limit int32) bool {
	if limit <= 0 {
		return true
	}
	// CAS loop: retry when a concurrent acquire/release changes the count.
	for {
		current := wsConnCount.Load()
		if current >= limit {
			return false
		}
		if wsConnCount.CompareAndSwap(current, current+1) {
			return true
		}
	}
}
|
||||||
|
|
||||||
|
// tryAcquireOpsWSIPSlot attempts to reserve one per-IP connection slot.
// A blank IP or non-positive limit disables the check (always succeeds).
// The per-IP counter is created on first use via LoadOrStore, then bumped
// with a CAS loop to keep the cap strict under concurrency.
func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool {
	if strings.TrimSpace(clientIP) == "" || limit <= 0 {
		return true
	}

	v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{})
	counter, ok := v.(*atomic.Int32)
	if !ok {
		// Defensive: the map should only ever hold *atomic.Int32.
		return false
	}

	for {
		current := counter.Load()
		if current >= limit {
			return false
		}
		if counter.CompareAndSwap(current, current+1) {
			return true
		}
	}
}
|
||||||
|
|
||||||
|
func releaseOpsWSIPSlot(clientIP string) {
|
||||||
|
if strings.TrimSpace(clientIP) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok := wsConnCountByIP.Load(clientIP)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
counter, ok := v.(*atomic.Int32)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next := counter.Add(-1)
|
||||||
|
if next <= 0 {
|
||||||
|
// Best-effort cleanup; safe even if a new slot was acquired concurrently.
|
||||||
|
wsConnCountByIP.Delete(clientIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleQPSWebSocket runs the push/heartbeat loop for one client connection.
// A reader goroutine drains incoming frames (only control frames matter) and
// feeds client-initiated close frames through closeFrameCh; the main loop
// pushes cached QPS payloads, pings the client, and echoes close frames.
// The function returns only after the reader goroutine has exited.
func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
	if conn == nil {
		return
	}

	ctx, cancel := context.WithCancel(parentCtx)
	defer cancel()

	// conn.Close must run exactly once even though several paths reach it.
	var closeOnce sync.Once
	closeConn := func() {
		closeOnce.Do(func() {
			_ = conn.Close()
		})
	}

	// Buffered so the close handler (running on the reader goroutine) never blocks.
	closeFrameCh := make(chan []byte, 1)

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		defer cancel()

		conn.SetReadLimit(qpsWSMaxReadBytes)
		if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
			log.Printf("[OpsWS] set read deadline failed: %v", err)
			return
		}
		// Each pong extends the read deadline; a silent client times out.
		conn.SetPongHandler(func(string) error {
			return conn.SetReadDeadline(time.Now().Add(qpsWSPongWait))
		})
		// Capture the client's close frame so the writer can echo code/text back.
		conn.SetCloseHandler(func(code int, text string) error {
			select {
			case closeFrameCh <- websocket.FormatCloseMessage(code, text):
			default:
			}
			cancel()
			return nil
		})

		// Drain frames until error/close; application messages are discarded.
		for {
			_, _, err := conn.ReadMessage()
			if err != nil {
				// Expected close codes are not worth logging.
				if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
					log.Printf("[OpsWS] read failed: %v", err)
				}
				return
			}
		}
	}()

	// Push QPS data every 2 seconds (values are globally cached and refreshed at most once per qpsWSRefreshInterval).
	pushTicker := time.NewTicker(qpsWSPushInterval)
	defer pushTicker.Stop()

	// Heartbeat ping every 30 seconds.
	pingTicker := time.NewTicker(qpsWSPingInterval)
	defer pingTicker.Stop()

	writeWithTimeout := func(messageType int, data []byte) error {
		if err := conn.SetWriteDeadline(time.Now().Add(qpsWSWriteTimeout)); err != nil {
			return err
		}
		return conn.WriteMessage(messageType, data)
	}

	// sendClose echoes the client's close frame, or a normal-closure frame
	// when shutdown originated on our side.
	sendClose := func(closeFrame []byte) {
		if closeFrame == nil {
			closeFrame = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
		}
		_ = writeWithTimeout(websocket.CloseMessage, closeFrame)
	}

	for {
		select {
		case <-pushTicker.C:
			msg := qpsWSCache.getPayload()
			if msg == nil {
				// Cache not primed yet; skip this tick.
				continue
			}
			if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
				log.Printf("[OpsWS] write failed: %v", err)
				cancel()
				closeConn()
				wg.Wait()
				return
			}

		case <-pingTicker.C:
			if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
				log.Printf("[OpsWS] ping failed: %v", err)
				cancel()
				closeConn()
				wg.Wait()
				return
			}

		case closeFrame := <-closeFrameCh:
			// Client initiated the close: echo and tear down.
			sendClose(closeFrame)
			closeConn()
			wg.Wait()
			return

		case <-ctx.Done():
			// Context canceled (parent or reader); pick up a pending client
			// close frame if one raced in, otherwise send a normal closure.
			var closeFrame []byte
			select {
			case closeFrame = <-closeFrameCh:
			default:
			}
			sendClose(closeFrame)

			closeConn()
			wg.Wait()
			return
		}
	}
}
|
||||||
|
|
||||||
|
// isAllowedOpsWSOrigin implements a same-host Origin check for the Ops WS
// endpoint: the Origin header's host must equal the request's host (or the
// X-Forwarded-Host when proxy headers are trusted). Requests without an
// Origin header are governed by the configured policy: "strict" rejects them,
// "permissive" (the default, and any unknown value) accepts them.
func isAllowedOpsWSOrigin(r *http.Request) bool {
	if r == nil {
		return false
	}
	origin := strings.TrimSpace(r.Header.Get("Origin"))
	if origin == "" {
		// No Origin (non-browser clients): decide by policy.
		switch strings.ToLower(strings.TrimSpace(opsWSProxyConfig.OriginPolicy)) {
		case OriginPolicyStrict:
			return false
		case OriginPolicyPermissive, "":
			return true
		default:
			// Unknown policy value falls back to permissive.
			return true
		}
	}
	parsed, err := url.Parse(origin)
	if err != nil || parsed.Hostname() == "" {
		return false
	}
	originHost := strings.ToLower(parsed.Hostname())

	// Determine the host the client believes it is talking to. Behind a
	// trusted proxy, X-Forwarded-Host (first entry) takes precedence.
	trustProxyHeaders := shouldTrustOpsWSProxyHeaders(r)
	reqHost := hostWithoutPort(r.Host)
	if trustProxyHeaders {
		xfHost := strings.TrimSpace(r.Header.Get("X-Forwarded-Host"))
		if xfHost != "" {
			xfHost = strings.TrimSpace(strings.Split(xfHost, ",")[0])
			if xfHost != "" {
				reqHost = hostWithoutPort(xfHost)
			}
		}
	}
	reqHost = strings.ToLower(reqHost)
	if reqHost == "" {
		return false
	}
	// Same-host comparison ignores ports by construction (hostnames only).
	return originHost == reqHost
}
|
||||||
|
|
||||||
|
func shouldTrustOpsWSProxyHeaders(r *http.Request) bool {
|
||||||
|
if r == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !opsWSProxyConfig.TrustProxy {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
peerIP, ok := requestPeerIP(r)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return isAddrInTrustedProxies(peerIP, opsWSProxyConfig.TrustedProxies)
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestPeerIP(r *http.Request) (netip.Addr, bool) {
|
||||||
|
if r == nil {
|
||||||
|
return netip.Addr{}, false
|
||||||
|
}
|
||||||
|
host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
|
||||||
|
if err != nil {
|
||||||
|
host = strings.TrimSpace(r.RemoteAddr)
|
||||||
|
}
|
||||||
|
host = strings.TrimPrefix(host, "[")
|
||||||
|
host = strings.TrimSuffix(host, "]")
|
||||||
|
if host == "" {
|
||||||
|
return netip.Addr{}, false
|
||||||
|
}
|
||||||
|
addr, err := netip.ParseAddr(host)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, false
|
||||||
|
}
|
||||||
|
return addr.Unmap(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestClientIP(r *http.Request) string {
|
||||||
|
if r == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
trustProxyHeaders := shouldTrustOpsWSProxyHeaders(r)
|
||||||
|
if trustProxyHeaders {
|
||||||
|
xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For"))
|
||||||
|
if xff != "" {
|
||||||
|
// Use the left-most entry (original client). If multiple proxies add values, they are comma-separated.
|
||||||
|
xff = strings.TrimSpace(strings.Split(xff, ",")[0])
|
||||||
|
xff = strings.TrimPrefix(xff, "[")
|
||||||
|
xff = strings.TrimSuffix(xff, "]")
|
||||||
|
if addr, err := netip.ParseAddr(xff); err == nil && addr.IsValid() {
|
||||||
|
return addr.Unmap().String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer, ok := requestPeerIP(r); ok && peer.IsValid() {
|
||||||
|
return peer.String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func isAddrInTrustedProxies(addr netip.Addr, trusted []netip.Prefix) bool {
|
||||||
|
if !addr.IsValid() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, p := range trusted {
|
||||||
|
if p.Contains(addr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
|
||||||
|
cfg := OpsWSProxyConfig{
|
||||||
|
TrustProxy: true,
|
||||||
|
TrustedProxies: defaultTrustedProxies(),
|
||||||
|
OriginPolicy: OriginPolicyPermissive,
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := strings.TrimSpace(os.Getenv(envOpsWSTrustProxy)); v != "" {
|
||||||
|
if parsed, err := strconv.ParseBool(v); err == nil {
|
||||||
|
cfg.TrustProxy = parsed
|
||||||
|
} else {
|
||||||
|
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
|
||||||
|
prefixes, invalid := parseTrustedProxyList(raw)
|
||||||
|
if len(invalid) > 0 {
|
||||||
|
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
|
||||||
|
}
|
||||||
|
cfg.TrustedProxies = prefixes
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := strings.TrimSpace(os.Getenv(envOpsWSOriginPolicy)); v != "" {
|
||||||
|
normalized := strings.ToLower(v)
|
||||||
|
switch normalized {
|
||||||
|
case OriginPolicyStrict, OriginPolicyPermissive:
|
||||||
|
cfg.OriginPolicy = normalized
|
||||||
|
default:
|
||||||
|
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits {
|
||||||
|
cfg := opsWSRuntimeLimits{
|
||||||
|
MaxConns: defaultMaxWSConns,
|
||||||
|
MaxConnsPerIP: defaultMaxWSConnsPerIP,
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConns)); v != "" {
|
||||||
|
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
|
||||||
|
cfg.MaxConns = int32(parsed)
|
||||||
|
} else {
|
||||||
|
log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" {
|
||||||
|
if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 {
|
||||||
|
cfg.MaxConnsPerIP = int32(parsed)
|
||||||
|
} else {
|
||||||
|
log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultTrustedProxies() []netip.Prefix {
|
||||||
|
prefixes, _ := parseTrustedProxyList("127.0.0.0/8,::1/128")
|
||||||
|
return prefixes
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseTrustedProxyList parses a comma-separated list of CIDR prefixes
// and/or single IP addresses. Single addresses become single-host prefixes
// (/32 for IPv4, /128 for IPv6, after unmapping IPv4-in-IPv6 forms).
// Each resulting prefix is canonicalized with Masked(). Empty tokens are
// skipped; unparseable tokens are returned in invalid.
func parseTrustedProxyList(raw string) (prefixes []netip.Prefix, invalid []string) {
	// parseOne converts a single trimmed token into a prefix.
	parseOne := func(entry string) (netip.Prefix, error) {
		if strings.Contains(entry, "/") {
			return netip.ParsePrefix(entry)
		}
		addr, err := netip.ParseAddr(entry)
		if err != nil {
			return netip.Prefix{}, err
		}
		addr = addr.Unmap()
		bits := 128
		if addr.Is4() {
			bits = 32
		}
		return netip.PrefixFrom(addr, bits), nil
	}

	for _, token := range strings.Split(raw, ",") {
		entry := strings.TrimSpace(token)
		if entry == "" {
			continue
		}
		p, err := parseOne(entry)
		if err != nil || !p.IsValid() {
			invalid = append(invalid, entry)
			continue
		}
		prefixes = append(prefixes, p.Masked())
	}
	return prefixes, invalid
}
|
||||||
|
|
||||||
|
// hostWithoutPort strips an optional port (and IPv6 brackets) from a
// host[:port] string, returning just the host portion.
//
// Handled forms:
//   - "example.com:8080" -> "example.com"
//   - "[::1]:8080"       -> "::1"
//   - "[::1]"            -> "::1"
//   - "2001:db8::1"      -> "2001:db8::1"  (bare IPv6, no port)
//   - "example.com"      -> "example.com"
//   - ""                 -> ""
func hostWithoutPort(hostport string) string {
	hostport = strings.TrimSpace(hostport)
	if hostport == "" {
		return ""
	}
	if host, _, err := net.SplitHostPort(hostport); err == nil {
		return host
	}
	if strings.HasPrefix(hostport, "[") && strings.HasSuffix(hostport, "]") {
		return strings.Trim(hostport, "[]")
	}
	// Fix: a bare IPv6 literal (multiple colons, no brackets) carries no
	// port; splitting on the first ":" would truncate it (e.g.
	// "2001:db8::1" -> "2001"), so return it unchanged instead.
	if strings.Count(hostport, ":") > 1 {
		return hostport
	}
	host, _, _ := strings.Cut(hostport, ":")
	return host
}
|
||||||
@@ -19,14 +19,16 @@ type SettingHandler struct {
|
|||||||
settingService *service.SettingService
|
settingService *service.SettingService
|
||||||
emailService *service.EmailService
|
emailService *service.EmailService
|
||||||
turnstileService *service.TurnstileService
|
turnstileService *service.TurnstileService
|
||||||
|
opsService *service.OpsService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSettingHandler 创建系统设置处理器
|
// NewSettingHandler 创建系统设置处理器
|
||||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService) *SettingHandler {
|
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler {
|
||||||
return &SettingHandler{
|
return &SettingHandler{
|
||||||
settingService: settingService,
|
settingService: settingService,
|
||||||
emailService: emailService,
|
emailService: emailService,
|
||||||
turnstileService: turnstileService,
|
turnstileService: turnstileService,
|
||||||
|
opsService: opsService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,6 +41,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if ops monitoring is enabled (respects config.ops.enabled)
|
||||||
|
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
|
||||||
|
|
||||||
response.Success(c, dto.SystemSettings{
|
response.Success(c, dto.SystemSettings{
|
||||||
RegistrationEnabled: settings.RegistrationEnabled,
|
RegistrationEnabled: settings.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||||
@@ -72,6 +77,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||||
EnableIdentityPatch: settings.EnableIdentityPatch,
|
EnableIdentityPatch: settings.EnableIdentityPatch,
|
||||||
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
||||||
|
OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
|
||||||
|
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
|
||||||
|
OpsQueryModeDefault: settings.OpsQueryModeDefault,
|
||||||
|
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,7 +104,7 @@ type UpdateSettingsRequest struct {
|
|||||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||||
TurnstileSecretKey string `json:"turnstile_secret_key"`
|
TurnstileSecretKey string `json:"turnstile_secret_key"`
|
||||||
|
|
||||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
// LinuxDo Connect OAuth 登录
|
||||||
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
|
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
|
||||||
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
|
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
|
||||||
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
|
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
|
||||||
@@ -124,6 +133,12 @@ type UpdateSettingsRequest struct {
|
|||||||
// Identity patch configuration (Claude -> Gemini)
|
// Identity patch configuration (Claude -> Gemini)
|
||||||
EnableIdentityPatch bool `json:"enable_identity_patch"`
|
EnableIdentityPatch bool `json:"enable_identity_patch"`
|
||||||
IdentityPatchPrompt string `json:"identity_patch_prompt"`
|
IdentityPatchPrompt string `json:"identity_patch_prompt"`
|
||||||
|
|
||||||
|
// Ops monitoring (vNext)
|
||||||
|
OpsMonitoringEnabled *bool `json:"ops_monitoring_enabled"`
|
||||||
|
OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"`
|
||||||
|
OpsQueryModeDefault *string `json:"ops_query_mode_default"`
|
||||||
|
OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSettings 更新系统设置
|
// UpdateSettings 更新系统设置
|
||||||
@@ -208,6 +223,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ops metrics collector interval validation (seconds).
|
||||||
|
if req.OpsMetricsIntervalSeconds != nil {
|
||||||
|
v := *req.OpsMetricsIntervalSeconds
|
||||||
|
if v < 60 {
|
||||||
|
v = 60
|
||||||
|
}
|
||||||
|
if v > 3600 {
|
||||||
|
v = 3600
|
||||||
|
}
|
||||||
|
req.OpsMetricsIntervalSeconds = &v
|
||||||
|
}
|
||||||
|
|
||||||
settings := &service.SystemSettings{
|
settings := &service.SystemSettings{
|
||||||
RegistrationEnabled: req.RegistrationEnabled,
|
RegistrationEnabled: req.RegistrationEnabled,
|
||||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||||
@@ -241,6 +268,30 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||||
|
OpsMonitoringEnabled: func() bool {
|
||||||
|
if req.OpsMonitoringEnabled != nil {
|
||||||
|
return *req.OpsMonitoringEnabled
|
||||||
|
}
|
||||||
|
return previousSettings.OpsMonitoringEnabled
|
||||||
|
}(),
|
||||||
|
OpsRealtimeMonitoringEnabled: func() bool {
|
||||||
|
if req.OpsRealtimeMonitoringEnabled != nil {
|
||||||
|
return *req.OpsRealtimeMonitoringEnabled
|
||||||
|
}
|
||||||
|
return previousSettings.OpsRealtimeMonitoringEnabled
|
||||||
|
}(),
|
||||||
|
OpsQueryModeDefault: func() string {
|
||||||
|
if req.OpsQueryModeDefault != nil {
|
||||||
|
return *req.OpsQueryModeDefault
|
||||||
|
}
|
||||||
|
return previousSettings.OpsQueryModeDefault
|
||||||
|
}(),
|
||||||
|
OpsMetricsIntervalSeconds: func() int {
|
||||||
|
if req.OpsMetricsIntervalSeconds != nil {
|
||||||
|
return *req.OpsMetricsIntervalSeconds
|
||||||
|
}
|
||||||
|
return previousSettings.OpsMetricsIntervalSeconds
|
||||||
|
}(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||||
@@ -290,6 +341,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||||
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
||||||
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
||||||
|
OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
|
||||||
|
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
|
||||||
|
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
|
||||||
|
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -411,6 +466,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
|
if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
|
||||||
changed = append(changed, "identity_patch_prompt")
|
changed = append(changed, "identity_patch_prompt")
|
||||||
}
|
}
|
||||||
|
if before.OpsMonitoringEnabled != after.OpsMonitoringEnabled {
|
||||||
|
changed = append(changed, "ops_monitoring_enabled")
|
||||||
|
}
|
||||||
|
if before.OpsRealtimeMonitoringEnabled != after.OpsRealtimeMonitoringEnabled {
|
||||||
|
changed = append(changed, "ops_realtime_monitoring_enabled")
|
||||||
|
}
|
||||||
|
if before.OpsQueryModeDefault != after.OpsQueryModeDefault {
|
||||||
|
changed = append(changed, "ops_query_mode_default")
|
||||||
|
}
|
||||||
|
if before.OpsMetricsIntervalSeconds != after.OpsMetricsIntervalSeconds {
|
||||||
|
changed = append(changed, "ops_metrics_interval_seconds")
|
||||||
|
}
|
||||||
return changed
|
return changed
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -587,3 +654,68 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
|
|||||||
|
|
||||||
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
response.Success(c, gin.H{"message": "Admin API key deleted"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStreamTimeoutSettings 获取流超时处理配置
|
||||||
|
// GET /api/v1/admin/settings/stream-timeout
|
||||||
|
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
||||||
|
settings, err := h.settingService.GetStreamTimeoutSettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.StreamTimeoutSettings{
|
||||||
|
Enabled: settings.Enabled,
|
||||||
|
Action: settings.Action,
|
||||||
|
TempUnschedMinutes: settings.TempUnschedMinutes,
|
||||||
|
ThresholdCount: settings.ThresholdCount,
|
||||||
|
ThresholdWindowMinutes: settings.ThresholdWindowMinutes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||||
|
type UpdateStreamTimeoutSettingsRequest struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
TempUnschedMinutes int `json:"temp_unsched_minutes"`
|
||||||
|
ThresholdCount int `json:"threshold_count"`
|
||||||
|
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStreamTimeoutSettings 更新流超时处理配置
|
||||||
|
// PUT /api/v1/admin/settings/stream-timeout
|
||||||
|
func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
|
||||||
|
var req UpdateStreamTimeoutSettingsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := &service.StreamTimeoutSettings{
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
Action: req.Action,
|
||||||
|
TempUnschedMinutes: req.TempUnschedMinutes,
|
||||||
|
ThresholdCount: req.ThresholdCount,
|
||||||
|
ThresholdWindowMinutes: req.ThresholdWindowMinutes,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.settingService.SetStreamTimeoutSettings(c.Request.Context(), settings); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重新获取设置返回
|
||||||
|
updatedSettings, err := h.settingService.GetStreamTimeoutSettings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.StreamTimeoutSettings{
|
||||||
|
Enabled: updatedSettings.Enabled,
|
||||||
|
Action: updatedSettings.Action,
|
||||||
|
TempUnschedMinutes: updatedSettings.TempUnschedMinutes,
|
||||||
|
ThresholdCount: updatedSettings.ThresholdCount,
|
||||||
|
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@@ -76,7 +77,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
|||||||
|
|
||||||
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
|
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
|
||||||
if req.VerifyCode == "" {
|
if req.VerifyCode == "" {
|
||||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -105,7 +106,7 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Turnstile 验证
|
// Turnstile 验证
|
||||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -132,7 +133,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Turnstile 验证
|
// Turnstile 验证
|
||||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
|
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,6 +43,12 @@ type SystemSettings struct {
|
|||||||
// Identity patch configuration (Claude -> Gemini)
|
// Identity patch configuration (Claude -> Gemini)
|
||||||
EnableIdentityPatch bool `json:"enable_identity_patch"`
|
EnableIdentityPatch bool `json:"enable_identity_patch"`
|
||||||
IdentityPatchPrompt string `json:"identity_patch_prompt"`
|
IdentityPatchPrompt string `json:"identity_patch_prompt"`
|
||||||
|
|
||||||
|
// Ops monitoring (vNext)
|
||||||
|
OpsMonitoringEnabled bool `json:"ops_monitoring_enabled"`
|
||||||
|
OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"`
|
||||||
|
OpsQueryModeDefault string `json:"ops_query_mode_default"`
|
||||||
|
OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PublicSettings struct {
|
type PublicSettings struct {
|
||||||
@@ -60,3 +66,12 @@ type PublicSettings struct {
|
|||||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||||
|
type StreamTimeoutSettings struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
TempUnschedMinutes int `json:"temp_unsched_minutes"`
|
||||||
|
ThresholdCount int `json:"threshold_count"`
|
||||||
|
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -89,6 +89,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||||
|
SetClaudeCodeClientContext(c, body)
|
||||||
|
|
||||||
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
parsedReq, err := service.ParseGatewayRequest(body)
|
parsedReq, err := service.ParseGatewayRequest(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
@@ -97,8 +102,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
reqModel := parsedReq.Model
|
reqModel := parsedReq.Model
|
||||||
reqStream := parsedReq.Stream
|
reqStream := parsedReq.Stream
|
||||||
|
|
||||||
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
SetClaudeCodeClientContext(c, body)
|
|
||||||
|
|
||||||
// 验证 model 必填
|
// 验证 model 必填
|
||||||
if reqModel == "" {
|
if reqModel == "" {
|
||||||
@@ -112,15 +116,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 获取订阅信息(可能为nil)- 提前获取用于后续检查
|
// 获取订阅信息(可能为nil)- 提前获取用于后续检查
|
||||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||||
|
|
||||||
// 获取 User-Agent
|
|
||||||
userAgent := c.Request.UserAgent()
|
|
||||||
|
|
||||||
// 获取客户端 IP
|
|
||||||
clientIP := ip.GetClientIP(c)
|
|
||||||
|
|
||||||
// 0. 检查wait队列是否已满
|
// 0. 检查wait队列是否已满
|
||||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||||
|
waitCounted := false
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Increment wait count failed: %v", err)
|
log.Printf("Increment wait count failed: %v", err)
|
||||||
// On error, allow request to proceed
|
// On error, allow request to proceed
|
||||||
@@ -128,8 +127,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 确保在函数退出时减少wait计数
|
if err == nil && canWait {
|
||||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
waitCounted = true
|
||||||
|
}
|
||||||
|
// Ensure we decrement if we exit before acquiring the user slot.
|
||||||
|
defer func() {
|
||||||
|
if waitCounted {
|
||||||
|
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// 1. 首先获取用户并发槽位
|
// 1. 首先获取用户并发槽位
|
||||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||||
@@ -138,6 +144,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// User slot acquired: no longer waiting in the queue.
|
||||||
|
if waitCounted {
|
||||||
|
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||||
|
waitCounted = false
|
||||||
|
}
|
||||||
// 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏
|
// 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏
|
||||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||||
if userReleaseFunc != nil {
|
if userReleaseFunc != nil {
|
||||||
@@ -184,6 +195,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
|
setOpsSelectedAccount(c, account.ID)
|
||||||
|
|
||||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||||
@@ -200,12 +212,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
// 3. 获取账号并发槽位
|
// 3. 获取账号并发槽位
|
||||||
accountReleaseFunc := selection.ReleaseFunc
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
var accountWaitRelease func()
|
|
||||||
if !selection.Acquired {
|
if !selection.Acquired {
|
||||||
if selection.WaitPlan == nil {
|
if selection.WaitPlan == nil {
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
accountWaitCounted := false
|
||||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Increment account wait count failed: %v", err)
|
log.Printf("Increment account wait count failed: %v", err)
|
||||||
@@ -213,12 +225,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||||
return
|
return
|
||||||
} else {
|
}
|
||||||
// Only set release function if increment succeeded
|
if err == nil && canWait {
|
||||||
accountWaitRelease = func() {
|
accountWaitCounted = true
|
||||||
|
}
|
||||||
|
// Ensure the wait counter is decremented if we exit before acquiring the slot.
|
||||||
|
defer func() {
|
||||||
|
if accountWaitCounted {
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
}
|
}
|
||||||
}
|
}()
|
||||||
|
|
||||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||||
c,
|
c,
|
||||||
@@ -229,20 +245,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
&streamStarted,
|
&streamStarted,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if accountWaitRelease != nil {
|
|
||||||
accountWaitRelease()
|
|
||||||
}
|
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Slot acquired: no longer waiting in queue.
|
||||||
|
if accountWaitCounted {
|
||||||
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
accountWaitCounted = false
|
||||||
|
}
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||||
log.Printf("Bind sticky session failed: %v", err)
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||||
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
|
|
||||||
|
|
||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
@@ -254,19 +271,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
}
|
}
|
||||||
if accountWaitRelease != nil {
|
|
||||||
accountWaitRelease()
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
switchCount++
|
switchCount++
|
||||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
continue
|
continue
|
||||||
@@ -276,8 +289,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
// 异步记录使用量(subscription已在函数开头获取)
|
// 异步记录使用量(subscription已在函数开头获取)
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
|
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
@@ -287,7 +304,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: ua,
|
UserAgent: ua,
|
||||||
IPAddress: cip,
|
IPAddress: clientIP,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -313,6 +330,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
|
setOpsSelectedAccount(c, account.ID)
|
||||||
|
|
||||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||||
@@ -329,12 +347,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
// 3. 获取账号并发槽位
|
// 3. 获取账号并发槽位
|
||||||
accountReleaseFunc := selection.ReleaseFunc
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
var accountWaitRelease func()
|
|
||||||
if !selection.Acquired {
|
if !selection.Acquired {
|
||||||
if selection.WaitPlan == nil {
|
if selection.WaitPlan == nil {
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
accountWaitCounted := false
|
||||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Increment account wait count failed: %v", err)
|
log.Printf("Increment account wait count failed: %v", err)
|
||||||
@@ -342,12 +360,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||||
return
|
return
|
||||||
} else {
|
}
|
||||||
// Only set release function if increment succeeded
|
if err == nil && canWait {
|
||||||
accountWaitRelease = func() {
|
accountWaitCounted = true
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if accountWaitCounted {
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
}
|
}
|
||||||
}
|
}()
|
||||||
|
|
||||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||||
c,
|
c,
|
||||||
@@ -358,20 +379,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
&streamStarted,
|
&streamStarted,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if accountWaitRelease != nil {
|
|
||||||
accountWaitRelease()
|
|
||||||
}
|
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if accountWaitCounted {
|
||||||
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
accountWaitCounted = false
|
||||||
|
}
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||||
log.Printf("Bind sticky session failed: %v", err)
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||||
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
|
|
||||||
|
|
||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
@@ -383,19 +404,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
}
|
}
|
||||||
if accountWaitRelease != nil {
|
|
||||||
accountWaitRelease()
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lastFailoverStatus = failoverErr.StatusCode
|
|
||||||
switchCount++
|
switchCount++
|
||||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
continue
|
continue
|
||||||
@@ -405,8 +422,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
// 异步记录使用量(subscription已在函数开头获取)
|
// 异步记录使用量(subscription已在函数开头获取)
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
|
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
@@ -416,7 +437,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: ua,
|
UserAgent: ua,
|
||||||
IPAddress: cip,
|
IPAddress: clientIP,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -686,21 +707,22 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
parsedReq, err := service.ParseGatewayRequest(body)
|
parsedReq, err := service.ParseGatewayRequest(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
|
||||||
SetClaudeCodeClientContext(c, body)
|
|
||||||
|
|
||||||
// 验证 model 必填
|
// 验证 model 必填
|
||||||
if parsedReq.Model == "" {
|
if parsedReq.Model == "" {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body)
|
||||||
|
|
||||||
// 获取订阅信息(可能为nil)
|
// 获取订阅信息(可能为nil)
|
||||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||||
|
|
||||||
@@ -721,6 +743,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
setOpsSelectedAccount(c, account.ID)
|
||||||
|
|
||||||
// 转发请求(不记录使用量)
|
// 转发请求(不记录使用量)
|
||||||
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
|
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
|
||||||
|
|||||||
@@ -162,28 +162,32 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setOpsRequestContext(c, modelName, stream, body)
|
||||||
|
|
||||||
// Get subscription (may be nil)
|
// Get subscription (may be nil)
|
||||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||||
|
|
||||||
// 获取 User-Agent
|
|
||||||
userAgent := c.Request.UserAgent()
|
|
||||||
|
|
||||||
// 获取客户端 IP
|
|
||||||
clientIP := ip.GetClientIP(c)
|
|
||||||
|
|
||||||
// For Gemini native API, do not send Claude-style ping frames.
|
// For Gemini native API, do not send Claude-style ping frames.
|
||||||
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
|
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
|
||||||
|
|
||||||
// 0) wait queue check
|
// 0) wait queue check
|
||||||
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
|
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
|
||||||
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
|
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
|
||||||
|
waitCounted := false
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Increment wait count failed: %v", err)
|
log.Printf("Increment wait count failed: %v", err)
|
||||||
} else if !canWait {
|
} else if !canWait {
|
||||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
|
if err == nil && canWait {
|
||||||
|
waitCounted = true
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if waitCounted {
|
||||||
|
geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// 1) user concurrency slot
|
// 1) user concurrency slot
|
||||||
streamStarted := false
|
streamStarted := false
|
||||||
@@ -192,6 +196,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if waitCounted {
|
||||||
|
geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
|
||||||
|
waitCounted = false
|
||||||
|
}
|
||||||
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
|
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
|
||||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||||
if userReleaseFunc != nil {
|
if userReleaseFunc != nil {
|
||||||
@@ -207,10 +215,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
|
|
||||||
// 3) select account (sticky session based on request body)
|
// 3) select account (sticky session based on request body)
|
||||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||||
|
|
||||||
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
|
||||||
SetClaudeCodeClientContext(c, body)
|
|
||||||
|
|
||||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||||
sessionKey := sessionHash
|
sessionKey := sessionHash
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
@@ -232,15 +236,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
|
setOpsSelectedAccount(c, account.ID)
|
||||||
|
|
||||||
// 4) account concurrency slot
|
// 4) account concurrency slot
|
||||||
accountReleaseFunc := selection.ReleaseFunc
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
var accountWaitRelease func()
|
|
||||||
if !selection.Acquired {
|
if !selection.Acquired {
|
||||||
if selection.WaitPlan == nil {
|
if selection.WaitPlan == nil {
|
||||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
|
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
accountWaitCounted := false
|
||||||
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Increment account wait count failed: %v", err)
|
log.Printf("Increment account wait count failed: %v", err)
|
||||||
@@ -248,12 +253,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||||
return
|
return
|
||||||
} else {
|
}
|
||||||
// Only set release function if increment succeeded
|
if err == nil && canWait {
|
||||||
accountWaitRelease = func() {
|
accountWaitCounted = true
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if accountWaitCounted {
|
||||||
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
}
|
}
|
||||||
}
|
}()
|
||||||
|
|
||||||
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
|
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
|
||||||
c,
|
c,
|
||||||
@@ -264,19 +272,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
&streamStarted,
|
&streamStarted,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if accountWaitRelease != nil {
|
|
||||||
accountWaitRelease()
|
|
||||||
}
|
|
||||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if accountWaitCounted {
|
||||||
|
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
accountWaitCounted = false
|
||||||
|
}
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||||
log.Printf("Bind sticky session failed: %v", err)
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||||
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
|
|
||||||
|
|
||||||
// 5) forward (根据平台分流)
|
// 5) forward (根据平台分流)
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
@@ -288,9 +296,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
}
|
}
|
||||||
if accountWaitRelease != nil {
|
|
||||||
accountWaitRelease()
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
@@ -310,8 +315,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
// 6) record usage async
|
// 6) record usage async
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
|
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
@@ -321,7 +330,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: ua,
|
UserAgent: ua,
|
||||||
IPAddress: cip,
|
IPAddress: ip,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type AdminHandlers struct {
|
|||||||
Redeem *admin.RedeemHandler
|
Redeem *admin.RedeemHandler
|
||||||
Promo *admin.PromoHandler
|
Promo *admin.PromoHandler
|
||||||
Setting *admin.SettingHandler
|
Setting *admin.SettingHandler
|
||||||
|
Ops *admin.OpsHandler
|
||||||
System *admin.SystemHandler
|
System *admin.SystemHandler
|
||||||
Subscription *admin.SubscriptionHandler
|
Subscription *admin.SubscriptionHandler
|
||||||
Usage *admin.UsageHandler
|
Usage *admin.UsageHandler
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
@@ -76,6 +77,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
// Parse request body to map for potential modification
|
// Parse request body to map for potential modification
|
||||||
var reqBody map[string]any
|
var reqBody map[string]any
|
||||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||||
@@ -93,19 +96,41 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// For non-Codex CLI requests, set default instructions
|
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
|
||||||
// 获取客户端 IP
|
|
||||||
clientIP := ip.GetClientIP(c)
|
|
||||||
|
|
||||||
if !openai.IsCodexCLIRequest(userAgent) {
|
if !openai.IsCodexCLIRequest(userAgent) {
|
||||||
reqBody["instructions"] = openai.DefaultInstructions
|
existingInstructions, _ := reqBody["instructions"].(string)
|
||||||
// Re-serialize body
|
if strings.TrimSpace(existingInstructions) == "" {
|
||||||
body, err = json.Marshal(reqBody)
|
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
|
||||||
if err != nil {
|
reqBody["instructions"] = instructions
|
||||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
// Re-serialize body
|
||||||
return
|
body, err = json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
|
|
||||||
|
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||||
|
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
|
||||||
|
// 或带 id 且与 call_id 匹配的 item_reference。
|
||||||
|
if service.HasFunctionCallOutput(reqBody) {
|
||||||
|
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||||
|
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
||||||
|
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
||||||
|
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
||||||
|
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
||||||
|
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,6 +143,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
// 0. Check if wait queue is full
|
// 0. Check if wait queue is full
|
||||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||||
|
waitCounted := false
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Increment wait count failed: %v", err)
|
log.Printf("Increment wait count failed: %v", err)
|
||||||
// On error, allow request to proceed
|
// On error, allow request to proceed
|
||||||
@@ -125,8 +151,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Ensure wait count is decremented when function exits
|
if err == nil && canWait {
|
||||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
waitCounted = true
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if waitCounted {
|
||||||
|
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// 1. First acquire user concurrency slot
|
// 1. First acquire user concurrency slot
|
||||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||||
@@ -135,6 +167,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// User slot acquired: no longer waiting.
|
||||||
|
if waitCounted {
|
||||||
|
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||||
|
waitCounted = false
|
||||||
|
}
|
||||||
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
|
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
|
||||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||||
if userReleaseFunc != nil {
|
if userReleaseFunc != nil {
|
||||||
@@ -172,15 +209,16 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||||
|
setOpsSelectedAccount(c, account.ID)
|
||||||
|
|
||||||
// 3. Acquire account concurrency slot
|
// 3. Acquire account concurrency slot
|
||||||
accountReleaseFunc := selection.ReleaseFunc
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
var accountWaitRelease func()
|
|
||||||
if !selection.Acquired {
|
if !selection.Acquired {
|
||||||
if selection.WaitPlan == nil {
|
if selection.WaitPlan == nil {
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
accountWaitCounted := false
|
||||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Increment account wait count failed: %v", err)
|
log.Printf("Increment account wait count failed: %v", err)
|
||||||
@@ -188,12 +226,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||||
return
|
return
|
||||||
} else {
|
}
|
||||||
// Only set release function if increment succeeded
|
if err == nil && canWait {
|
||||||
accountWaitRelease = func() {
|
accountWaitCounted = true
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if accountWaitCounted {
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
}
|
}
|
||||||
}
|
}()
|
||||||
|
|
||||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||||
c,
|
c,
|
||||||
@@ -204,29 +245,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
&streamStarted,
|
&streamStarted,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if accountWaitRelease != nil {
|
|
||||||
accountWaitRelease()
|
|
||||||
}
|
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if accountWaitCounted {
|
||||||
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
accountWaitCounted = false
|
||||||
|
}
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||||
log.Printf("Bind sticky session failed: %v", err)
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||||
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
|
|
||||||
|
|
||||||
// Forward request
|
// Forward request
|
||||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
}
|
}
|
||||||
if accountWaitRelease != nil {
|
|
||||||
accountWaitRelease()
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
@@ -246,8 +284,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
|
||||||
// Async record usage
|
// Async record usage
|
||||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string, cip string) {
|
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
@@ -257,7 +299,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
Account: usedAccount,
|
Account: usedAccount,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: ua,
|
UserAgent: ua,
|
||||||
IPAddress: cip,
|
IPAddress: ip,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Printf("Record usage failed: %v", err)
|
log.Printf("Record usage failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
965
backend/internal/handler/ops_error_logger.go
Normal file
965
backend/internal/handler/ops_error_logger.go
Normal file
@@ -0,0 +1,965 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"log"
|
||||||
|
"runtime"
|
||||||
|
"runtime/debug"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
opsModelKey = "ops_model"
|
||||||
|
opsStreamKey = "ops_stream"
|
||||||
|
opsRequestBodyKey = "ops_request_body"
|
||||||
|
opsAccountIDKey = "ops_account_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
opsErrorLogTimeout = 5 * time.Second
|
||||||
|
opsErrorLogDrainTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
opsErrorLogMinWorkerCount = 4
|
||||||
|
opsErrorLogMaxWorkerCount = 32
|
||||||
|
|
||||||
|
opsErrorLogQueueSizePerWorker = 128
|
||||||
|
opsErrorLogMinQueueSize = 256
|
||||||
|
opsErrorLogMaxQueueSize = 8192
|
||||||
|
)
|
||||||
|
|
||||||
|
type opsErrorLogJob struct {
|
||||||
|
ops *service.OpsService
|
||||||
|
entry *service.OpsInsertErrorLogInput
|
||||||
|
requestBody []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
opsErrorLogOnce sync.Once
|
||||||
|
opsErrorLogQueue chan opsErrorLogJob
|
||||||
|
|
||||||
|
opsErrorLogStopOnce sync.Once
|
||||||
|
opsErrorLogWorkersWg sync.WaitGroup
|
||||||
|
opsErrorLogMu sync.RWMutex
|
||||||
|
opsErrorLogStopping bool
|
||||||
|
opsErrorLogQueueLen atomic.Int64
|
||||||
|
opsErrorLogEnqueued atomic.Int64
|
||||||
|
opsErrorLogDropped atomic.Int64
|
||||||
|
opsErrorLogProcessed atomic.Int64
|
||||||
|
|
||||||
|
opsErrorLogLastDropLogAt atomic.Int64
|
||||||
|
|
||||||
|
opsErrorLogShutdownCh = make(chan struct{})
|
||||||
|
opsErrorLogShutdownOnce sync.Once
|
||||||
|
opsErrorLogDrained atomic.Bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func startOpsErrorLogWorkers() {
|
||||||
|
opsErrorLogMu.Lock()
|
||||||
|
defer opsErrorLogMu.Unlock()
|
||||||
|
|
||||||
|
if opsErrorLogStopping {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
workerCount, queueSize := opsErrorLogConfig()
|
||||||
|
opsErrorLogQueue = make(chan opsErrorLogJob, queueSize)
|
||||||
|
opsErrorLogQueueLen.Store(0)
|
||||||
|
|
||||||
|
opsErrorLogWorkersWg.Add(workerCount)
|
||||||
|
for i := 0; i < workerCount; i++ {
|
||||||
|
go func() {
|
||||||
|
defer opsErrorLogWorkersWg.Done()
|
||||||
|
for job := range opsErrorLogQueue {
|
||||||
|
opsErrorLogQueueLen.Add(-1)
|
||||||
|
if job.ops == nil || job.entry == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||||
|
_ = job.ops.RecordError(ctx, job.entry, job.requestBody)
|
||||||
|
cancel()
|
||||||
|
opsErrorLogProcessed.Add(1)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) {
|
||||||
|
if ops == nil || entry == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-opsErrorLogShutdownCh:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
opsErrorLogMu.RLock()
|
||||||
|
stopping := opsErrorLogStopping
|
||||||
|
opsErrorLogMu.RUnlock()
|
||||||
|
if stopping {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
opsErrorLogOnce.Do(startOpsErrorLogWorkers)
|
||||||
|
|
||||||
|
opsErrorLogMu.RLock()
|
||||||
|
defer opsErrorLogMu.RUnlock()
|
||||||
|
if opsErrorLogStopping || opsErrorLogQueue == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}:
|
||||||
|
opsErrorLogQueueLen.Add(1)
|
||||||
|
opsErrorLogEnqueued.Add(1)
|
||||||
|
default:
|
||||||
|
// Queue is full; drop to avoid blocking request handling.
|
||||||
|
opsErrorLogDropped.Add(1)
|
||||||
|
maybeLogOpsErrorLogDrop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func StopOpsErrorLogWorkers() bool {
|
||||||
|
opsErrorLogStopOnce.Do(func() {
|
||||||
|
opsErrorLogShutdownOnce.Do(func() {
|
||||||
|
close(opsErrorLogShutdownCh)
|
||||||
|
})
|
||||||
|
opsErrorLogDrained.Store(stopOpsErrorLogWorkers())
|
||||||
|
})
|
||||||
|
return opsErrorLogDrained.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopOpsErrorLogWorkers() bool {
|
||||||
|
opsErrorLogMu.Lock()
|
||||||
|
opsErrorLogStopping = true
|
||||||
|
ch := opsErrorLogQueue
|
||||||
|
if ch != nil {
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
opsErrorLogQueue = nil
|
||||||
|
opsErrorLogMu.Unlock()
|
||||||
|
|
||||||
|
if ch == nil {
|
||||||
|
opsErrorLogQueueLen.Store(0)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
opsErrorLogWorkersWg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
opsErrorLogQueueLen.Store(0)
|
||||||
|
return true
|
||||||
|
case <-time.After(opsErrorLogDrainTimeout):
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpsErrorLogQueueLength() int64 {
|
||||||
|
return opsErrorLogQueueLen.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpsErrorLogQueueCapacity() int {
|
||||||
|
opsErrorLogMu.RLock()
|
||||||
|
ch := opsErrorLogQueue
|
||||||
|
opsErrorLogMu.RUnlock()
|
||||||
|
if ch == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return cap(ch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpsErrorLogDroppedTotal() int64 {
|
||||||
|
return opsErrorLogDropped.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpsErrorLogEnqueuedTotal() int64 {
|
||||||
|
return opsErrorLogEnqueued.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpsErrorLogProcessedTotal() int64 {
|
||||||
|
return opsErrorLogProcessed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func maybeLogOpsErrorLogDrop() {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
|
||||||
|
for {
|
||||||
|
last := opsErrorLogLastDropLogAt.Load()
|
||||||
|
if last != 0 && now-last < 60 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if opsErrorLogLastDropLogAt.CompareAndSwap(last, now) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
queued := opsErrorLogQueueLen.Load()
|
||||||
|
queueCap := OpsErrorLogQueueCapacity()
|
||||||
|
|
||||||
|
log.Printf(
|
||||||
|
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)",
|
||||||
|
queued,
|
||||||
|
queueCap,
|
||||||
|
opsErrorLogEnqueued.Load(),
|
||||||
|
opsErrorLogDropped.Load(),
|
||||||
|
opsErrorLogProcessed.Load(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func opsErrorLogConfig() (workerCount int, queueSize int) {
|
||||||
|
workerCount = runtime.GOMAXPROCS(0) * 2
|
||||||
|
if workerCount < opsErrorLogMinWorkerCount {
|
||||||
|
workerCount = opsErrorLogMinWorkerCount
|
||||||
|
}
|
||||||
|
if workerCount > opsErrorLogMaxWorkerCount {
|
||||||
|
workerCount = opsErrorLogMaxWorkerCount
|
||||||
|
}
|
||||||
|
|
||||||
|
queueSize = workerCount * opsErrorLogQueueSizePerWorker
|
||||||
|
if queueSize < opsErrorLogMinQueueSize {
|
||||||
|
queueSize = opsErrorLogMinQueueSize
|
||||||
|
}
|
||||||
|
if queueSize > opsErrorLogMaxQueueSize {
|
||||||
|
queueSize = opsErrorLogMaxQueueSize
|
||||||
|
}
|
||||||
|
|
||||||
|
return workerCount, queueSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody []byte) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(opsModelKey, model)
|
||||||
|
c.Set(opsStreamKey, stream)
|
||||||
|
if len(requestBody) > 0 {
|
||||||
|
c.Set(opsRequestBodyKey, requestBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setOpsSelectedAccount(c *gin.Context, accountID int64) {
|
||||||
|
if c == nil || accountID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(opsAccountIDKey, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
type opsCaptureWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
limit int
|
||||||
|
buf bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *opsCaptureWriter) Write(b []byte) (int, error) {
|
||||||
|
if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit {
|
||||||
|
remaining := w.limit - w.buf.Len()
|
||||||
|
if len(b) > remaining {
|
||||||
|
_, _ = w.buf.Write(b[:remaining])
|
||||||
|
} else {
|
||||||
|
_, _ = w.buf.Write(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return w.ResponseWriter.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *opsCaptureWriter) WriteString(s string) (int, error) {
|
||||||
|
if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit {
|
||||||
|
remaining := w.limit - w.buf.Len()
|
||||||
|
if len(s) > remaining {
|
||||||
|
_, _ = w.buf.WriteString(s[:remaining])
|
||||||
|
} else {
|
||||||
|
_, _ = w.buf.WriteString(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return w.ResponseWriter.WriteString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpsErrorLoggerMiddleware records error responses (status >= 400) into ops_error_logs.
|
||||||
|
//
|
||||||
|
// Notes:
|
||||||
|
// - It buffers response bodies only when status >= 400 to avoid overhead for successful traffic.
|
||||||
|
// - Streaming errors after the response has started (SSE) may still need explicit logging.
|
||||||
|
func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024}
|
||||||
|
c.Writer = w
|
||||||
|
c.Next()
|
||||||
|
|
||||||
|
if ops == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !ops.IsMonitoringEnabled(c.Request.Context()) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
status := c.Writer.Status()
|
||||||
|
if status < 400 {
|
||||||
|
// Even when the client request succeeds, we still want to persist upstream error attempts
|
||||||
|
// (retries/failover) so ops can observe upstream instability that gets "covered" by retries.
|
||||||
|
var events []*service.OpsUpstreamErrorEvent
|
||||||
|
if v, ok := c.Get(service.OpsUpstreamErrorsKey); ok {
|
||||||
|
if arr, ok := v.([]*service.OpsUpstreamErrorEvent); ok && len(arr) > 0 {
|
||||||
|
events = arr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Also accept single upstream fields set by gateway services (rare for successful requests).
|
||||||
|
hasUpstreamContext := len(events) > 0
|
||||||
|
if !hasUpstreamContext {
|
||||||
|
if v, ok := c.Get(service.OpsUpstreamStatusCodeKey); ok {
|
||||||
|
switch t := v.(type) {
|
||||||
|
case int:
|
||||||
|
hasUpstreamContext = t > 0
|
||||||
|
case int64:
|
||||||
|
hasUpstreamContext = t > 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasUpstreamContext {
|
||||||
|
if v, ok := c.Get(service.OpsUpstreamErrorMessageKey); ok {
|
||||||
|
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
|
||||||
|
hasUpstreamContext = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasUpstreamContext {
|
||||||
|
if v, ok := c.Get(service.OpsUpstreamErrorDetailKey); ok {
|
||||||
|
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
|
||||||
|
hasUpstreamContext = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasUpstreamContext {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
|
||||||
|
clientRequestID, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
||||||
|
|
||||||
|
model, _ := c.Get(opsModelKey)
|
||||||
|
streamV, _ := c.Get(opsStreamKey)
|
||||||
|
accountIDV, _ := c.Get(opsAccountIDKey)
|
||||||
|
|
||||||
|
var modelName string
|
||||||
|
if s, ok := model.(string); ok {
|
||||||
|
modelName = s
|
||||||
|
}
|
||||||
|
stream := false
|
||||||
|
if b, ok := streamV.(bool); ok {
|
||||||
|
stream = b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefer showing the account that experienced the upstream error (if we have events),
|
||||||
|
// otherwise fall back to the final selected account (best-effort).
|
||||||
|
var accountID *int64
|
||||||
|
if len(events) > 0 {
|
||||||
|
if last := events[len(events)-1]; last != nil && last.AccountID > 0 {
|
||||||
|
v := last.AccountID
|
||||||
|
accountID = &v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if accountID == nil {
|
||||||
|
if v, ok := accountIDV.(int64); ok && v > 0 {
|
||||||
|
accountID = &v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fallbackPlatform := guessPlatformFromPath(c.Request.URL.Path)
|
||||||
|
platform := resolveOpsPlatform(apiKey, fallbackPlatform)
|
||||||
|
|
||||||
|
requestID := c.Writer.Header().Get("X-Request-Id")
|
||||||
|
if requestID == "" {
|
||||||
|
requestID = c.Writer.Header().Get("x-request-id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Best-effort backfill single upstream fields from the last event (if present).
|
||||||
|
var upstreamStatusCode *int
|
||||||
|
var upstreamErrorMessage *string
|
||||||
|
var upstreamErrorDetail *string
|
||||||
|
if len(events) > 0 {
|
||||||
|
last := events[len(events)-1]
|
||||||
|
if last != nil {
|
||||||
|
if last.UpstreamStatusCode > 0 {
|
||||||
|
code := last.UpstreamStatusCode
|
||||||
|
upstreamStatusCode = &code
|
||||||
|
}
|
||||||
|
if msg := strings.TrimSpace(last.Message); msg != "" {
|
||||||
|
upstreamErrorMessage = &msg
|
||||||
|
}
|
||||||
|
if detail := strings.TrimSpace(last.Detail); detail != "" {
|
||||||
|
upstreamErrorDetail = &detail
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if upstreamStatusCode == nil {
|
||||||
|
if v, ok := c.Get(service.OpsUpstreamStatusCodeKey); ok {
|
||||||
|
switch t := v.(type) {
|
||||||
|
case int:
|
||||||
|
if t > 0 {
|
||||||
|
code := t
|
||||||
|
upstreamStatusCode = &code
|
||||||
|
}
|
||||||
|
case int64:
|
||||||
|
if t > 0 {
|
||||||
|
code := int(t)
|
||||||
|
upstreamStatusCode = &code
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if upstreamErrorMessage == nil {
|
||||||
|
if v, ok := c.Get(service.OpsUpstreamErrorMessageKey); ok {
|
||||||
|
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
|
||||||
|
msg := strings.TrimSpace(s)
|
||||||
|
upstreamErrorMessage = &msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if upstreamErrorDetail == nil {
|
||||||
|
if v, ok := c.Get(service.OpsUpstreamErrorDetailKey); ok {
|
||||||
|
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
|
||||||
|
detail := strings.TrimSpace(s)
|
||||||
|
upstreamErrorDetail = &detail
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we still have nothing meaningful, skip.
|
||||||
|
if upstreamStatusCode == nil && upstreamErrorMessage == nil && upstreamErrorDetail == nil && len(events) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
effectiveUpstreamStatus := 0
|
||||||
|
if upstreamStatusCode != nil {
|
||||||
|
effectiveUpstreamStatus = *upstreamStatusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
recoveredMsg := "Recovered upstream error"
|
||||||
|
if effectiveUpstreamStatus > 0 {
|
||||||
|
recoveredMsg += " " + strconvItoa(effectiveUpstreamStatus)
|
||||||
|
}
|
||||||
|
if upstreamErrorMessage != nil && strings.TrimSpace(*upstreamErrorMessage) != "" {
|
||||||
|
recoveredMsg += ": " + strings.TrimSpace(*upstreamErrorMessage)
|
||||||
|
}
|
||||||
|
recoveredMsg = truncateString(recoveredMsg, 2048)
|
||||||
|
|
||||||
|
entry := &service.OpsInsertErrorLogInput{
|
||||||
|
RequestID: requestID,
|
||||||
|
ClientRequestID: clientRequestID,
|
||||||
|
|
||||||
|
AccountID: accountID,
|
||||||
|
Platform: platform,
|
||||||
|
Model: modelName,
|
||||||
|
RequestPath: func() string {
|
||||||
|
if c.Request != nil && c.Request.URL != nil {
|
||||||
|
return c.Request.URL.Path
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}(),
|
||||||
|
Stream: stream,
|
||||||
|
UserAgent: c.GetHeader("User-Agent"),
|
||||||
|
|
||||||
|
ErrorPhase: "upstream",
|
||||||
|
ErrorType: "upstream_error",
|
||||||
|
// Severity/retryability should reflect the upstream failure, not the final client status (200).
|
||||||
|
Severity: classifyOpsSeverity("upstream_error", effectiveUpstreamStatus),
|
||||||
|
StatusCode: status,
|
||||||
|
IsBusinessLimited: false,
|
||||||
|
IsCountTokens: isCountTokensRequest(c),
|
||||||
|
|
||||||
|
ErrorMessage: recoveredMsg,
|
||||||
|
ErrorBody: "",
|
||||||
|
|
||||||
|
ErrorSource: "upstream_http",
|
||||||
|
ErrorOwner: "provider",
|
||||||
|
|
||||||
|
UpstreamStatusCode: upstreamStatusCode,
|
||||||
|
UpstreamErrorMessage: upstreamErrorMessage,
|
||||||
|
UpstreamErrorDetail: upstreamErrorDetail,
|
||||||
|
UpstreamErrors: events,
|
||||||
|
|
||||||
|
IsRetryable: classifyOpsIsRetryable("upstream_error", effectiveUpstreamStatus),
|
||||||
|
RetryCount: 0,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if apiKey != nil {
|
||||||
|
entry.APIKeyID = &apiKey.ID
|
||||||
|
if apiKey.User != nil {
|
||||||
|
entry.UserID = &apiKey.User.ID
|
||||||
|
}
|
||||||
|
if apiKey.GroupID != nil {
|
||||||
|
entry.GroupID = apiKey.GroupID
|
||||||
|
}
|
||||||
|
// Prefer group platform if present (more stable than inferring from path).
|
||||||
|
if apiKey.Group != nil && apiKey.Group.Platform != "" {
|
||||||
|
entry.Platform = apiKey.Group.Platform
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var clientIP string
|
||||||
|
if ip := strings.TrimSpace(ip.GetClientIP(c)); ip != "" {
|
||||||
|
clientIP = ip
|
||||||
|
entry.ClientIP = &clientIP
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestBody []byte
|
||||||
|
if v, ok := c.Get(opsRequestBodyKey); ok {
|
||||||
|
if b, ok := v.([]byte); ok && len(b) > 0 {
|
||||||
|
requestBody = b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
|
||||||
|
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||||
|
|
||||||
|
enqueueOpsErrorLog(ops, entry, requestBody)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.buf.Bytes()
|
||||||
|
parsed := parseOpsErrorResponse(body)
|
||||||
|
|
||||||
|
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
|
||||||
|
|
||||||
|
clientRequestID, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
||||||
|
|
||||||
|
model, _ := c.Get(opsModelKey)
|
||||||
|
streamV, _ := c.Get(opsStreamKey)
|
||||||
|
accountIDV, _ := c.Get(opsAccountIDKey)
|
||||||
|
|
||||||
|
var modelName string
|
||||||
|
if s, ok := model.(string); ok {
|
||||||
|
modelName = s
|
||||||
|
}
|
||||||
|
stream := false
|
||||||
|
if b, ok := streamV.(bool); ok {
|
||||||
|
stream = b
|
||||||
|
}
|
||||||
|
var accountID *int64
|
||||||
|
if v, ok := accountIDV.(int64); ok && v > 0 {
|
||||||
|
accountID = &v
|
||||||
|
}
|
||||||
|
|
||||||
|
fallbackPlatform := guessPlatformFromPath(c.Request.URL.Path)
|
||||||
|
platform := resolveOpsPlatform(apiKey, fallbackPlatform)
|
||||||
|
|
||||||
|
requestID := c.Writer.Header().Get("X-Request-Id")
|
||||||
|
if requestID == "" {
|
||||||
|
requestID = c.Writer.Header().Get("x-request-id")
|
||||||
|
}
|
||||||
|
|
||||||
|
phase := classifyOpsPhase(parsed.ErrorType, parsed.Message, parsed.Code)
|
||||||
|
isBusinessLimited := classifyOpsIsBusinessLimited(parsed.ErrorType, phase, parsed.Code, status, parsed.Message)
|
||||||
|
|
||||||
|
errorOwner := classifyOpsErrorOwner(phase, parsed.Message)
|
||||||
|
errorSource := classifyOpsErrorSource(phase, parsed.Message)
|
||||||
|
|
||||||
|
entry := &service.OpsInsertErrorLogInput{
|
||||||
|
RequestID: requestID,
|
||||||
|
ClientRequestID: clientRequestID,
|
||||||
|
|
||||||
|
AccountID: accountID,
|
||||||
|
Platform: platform,
|
||||||
|
Model: modelName,
|
||||||
|
RequestPath: func() string {
|
||||||
|
if c.Request != nil && c.Request.URL != nil {
|
||||||
|
return c.Request.URL.Path
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}(),
|
||||||
|
Stream: stream,
|
||||||
|
UserAgent: c.GetHeader("User-Agent"),
|
||||||
|
|
||||||
|
ErrorPhase: phase,
|
||||||
|
ErrorType: normalizeOpsErrorType(parsed.ErrorType, parsed.Code),
|
||||||
|
Severity: classifyOpsSeverity(parsed.ErrorType, status),
|
||||||
|
StatusCode: status,
|
||||||
|
IsBusinessLimited: isBusinessLimited,
|
||||||
|
IsCountTokens: isCountTokensRequest(c),
|
||||||
|
|
||||||
|
ErrorMessage: parsed.Message,
|
||||||
|
// Keep the full captured error body (capture is already capped at 64KB) so the
|
||||||
|
// service layer can sanitize JSON before truncating for storage.
|
||||||
|
ErrorBody: string(body),
|
||||||
|
ErrorSource: errorSource,
|
||||||
|
ErrorOwner: errorOwner,
|
||||||
|
|
||||||
|
IsRetryable: classifyOpsIsRetryable(parsed.ErrorType, status),
|
||||||
|
RetryCount: 0,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture upstream error context set by gateway services (if present).
|
||||||
|
// This does NOT affect the client response; it enriches Ops troubleshooting data.
|
||||||
|
{
|
||||||
|
if v, ok := c.Get(service.OpsUpstreamStatusCodeKey); ok {
|
||||||
|
switch t := v.(type) {
|
||||||
|
case int:
|
||||||
|
if t > 0 {
|
||||||
|
code := t
|
||||||
|
entry.UpstreamStatusCode = &code
|
||||||
|
}
|
||||||
|
case int64:
|
||||||
|
if t > 0 {
|
||||||
|
code := int(t)
|
||||||
|
entry.UpstreamStatusCode = &code
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := c.Get(service.OpsUpstreamErrorMessageKey); ok {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
if msg := strings.TrimSpace(s); msg != "" {
|
||||||
|
entry.UpstreamErrorMessage = &msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := c.Get(service.OpsUpstreamErrorDetailKey); ok {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
if detail := strings.TrimSpace(s); detail != "" {
|
||||||
|
entry.UpstreamErrorDetail = &detail
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := c.Get(service.OpsUpstreamErrorsKey); ok {
|
||||||
|
if events, ok := v.([]*service.OpsUpstreamErrorEvent); ok && len(events) > 0 {
|
||||||
|
entry.UpstreamErrors = events
|
||||||
|
// Best-effort backfill the single upstream fields from the last event when missing.
|
||||||
|
last := events[len(events)-1]
|
||||||
|
if last != nil {
|
||||||
|
if entry.UpstreamStatusCode == nil && last.UpstreamStatusCode > 0 {
|
||||||
|
code := last.UpstreamStatusCode
|
||||||
|
entry.UpstreamStatusCode = &code
|
||||||
|
}
|
||||||
|
if entry.UpstreamErrorMessage == nil && strings.TrimSpace(last.Message) != "" {
|
||||||
|
msg := strings.TrimSpace(last.Message)
|
||||||
|
entry.UpstreamErrorMessage = &msg
|
||||||
|
}
|
||||||
|
if entry.UpstreamErrorDetail == nil && strings.TrimSpace(last.Detail) != "" {
|
||||||
|
detail := strings.TrimSpace(last.Detail)
|
||||||
|
entry.UpstreamErrorDetail = &detail
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if apiKey != nil {
|
||||||
|
entry.APIKeyID = &apiKey.ID
|
||||||
|
if apiKey.User != nil {
|
||||||
|
entry.UserID = &apiKey.User.ID
|
||||||
|
}
|
||||||
|
if apiKey.GroupID != nil {
|
||||||
|
entry.GroupID = apiKey.GroupID
|
||||||
|
}
|
||||||
|
// Prefer group platform if present (more stable than inferring from path).
|
||||||
|
if apiKey.Group != nil && apiKey.Group.Platform != "" {
|
||||||
|
entry.Platform = apiKey.Group.Platform
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var clientIP string
|
||||||
|
if ip := strings.TrimSpace(ip.GetClientIP(c)); ip != "" {
|
||||||
|
clientIP = ip
|
||||||
|
entry.ClientIP = &clientIP
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestBody []byte
|
||||||
|
if v, ok := c.Get(opsRequestBodyKey); ok {
|
||||||
|
if b, ok := v.([]byte); ok && len(b) > 0 {
|
||||||
|
requestBody = b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
|
||||||
|
// Do NOT store Authorization/Cookie/etc.
|
||||||
|
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||||
|
|
||||||
|
enqueueOpsErrorLog(ops, entry, requestBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var opsRetryRequestHeaderAllowlist = []string{
|
||||||
|
"anthropic-beta",
|
||||||
|
"anthropic-version",
|
||||||
|
}
|
||||||
|
|
||||||
|
// isCountTokensRequest checks if the request is a count_tokens request
|
||||||
|
func isCountTokensRequest(c *gin.Context) bool {
|
||||||
|
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(c.Request.URL.Path, "/count_tokens")
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpsRetryRequestHeaders(c *gin.Context) *string {
|
||||||
|
if c == nil || c.Request == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := make(map[string]string, 4)
|
||||||
|
for _, key := range opsRetryRequestHeaderAllowlist {
|
||||||
|
v := strings.TrimSpace(c.GetHeader(key))
|
||||||
|
if v == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Keep headers small even if a client sends something unexpected.
|
||||||
|
headers[key] = truncateString(v, 512)
|
||||||
|
}
|
||||||
|
if len(headers) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, err := json.Marshal(headers)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s := string(raw)
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
type parsedOpsError struct {
|
||||||
|
ErrorType string
|
||||||
|
Message string
|
||||||
|
Code string
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOpsErrorResponse(body []byte) parsedOpsError {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return parsedOpsError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fast path: attempt to decode into a generic map.
|
||||||
|
var m map[string]any
|
||||||
|
if err := json.Unmarshal(body, &m); err != nil {
|
||||||
|
return parsedOpsError{Message: truncateString(string(body), 1024)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Claude/OpenAI-style gateway error: { type:"error", error:{ type, message } }
|
||||||
|
if errObj, ok := m["error"].(map[string]any); ok {
|
||||||
|
t, _ := errObj["type"].(string)
|
||||||
|
msg, _ := errObj["message"].(string)
|
||||||
|
// Gemini googleError also uses "error": { code, message, status }
|
||||||
|
if msg == "" {
|
||||||
|
if v, ok := errObj["message"]; ok {
|
||||||
|
msg, _ = v.(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if t == "" {
|
||||||
|
// Gemini error does not have "type" field.
|
||||||
|
t = "api_error"
|
||||||
|
}
|
||||||
|
// For gemini error, capture numeric code as string for business-limited mapping if needed.
|
||||||
|
var code string
|
||||||
|
if v, ok := errObj["code"]; ok {
|
||||||
|
switch n := v.(type) {
|
||||||
|
case float64:
|
||||||
|
code = strconvItoa(int(n))
|
||||||
|
case int:
|
||||||
|
code = strconvItoa(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return parsedOpsError{ErrorType: t, Message: msg, Code: code}
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKeyAuth-style: { code:"INSUFFICIENT_BALANCE", message:"..." }
|
||||||
|
code, _ := m["code"].(string)
|
||||||
|
msg, _ := m["message"].(string)
|
||||||
|
if code != "" || msg != "" {
|
||||||
|
return parsedOpsError{ErrorType: "api_error", Message: msg, Code: code}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedOpsError{Message: truncateString(string(body), 1024)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveOpsPlatform(apiKey *service.APIKey, fallback string) string {
|
||||||
|
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform != "" {
|
||||||
|
return apiKey.Group.Platform
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func guessPlatformFromPath(path string) string {
|
||||||
|
p := strings.ToLower(path)
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(p, "/antigravity/"):
|
||||||
|
return service.PlatformAntigravity
|
||||||
|
case strings.HasPrefix(p, "/v1beta/"):
|
||||||
|
return service.PlatformGemini
|
||||||
|
case strings.Contains(p, "/responses"):
|
||||||
|
return service.PlatformOpenAI
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeOpsErrorType(errType string, code string) string {
|
||||||
|
if errType != "" {
|
||||||
|
return errType
|
||||||
|
}
|
||||||
|
switch strings.TrimSpace(code) {
|
||||||
|
case "INSUFFICIENT_BALANCE":
|
||||||
|
return "billing_error"
|
||||||
|
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||||
|
return "subscription_error"
|
||||||
|
default:
|
||||||
|
return "api_error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func classifyOpsPhase(errType, message, code string) string {
|
||||||
|
msg := strings.ToLower(message)
|
||||||
|
switch strings.TrimSpace(code) {
|
||||||
|
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||||
|
return "billing"
|
||||||
|
}
|
||||||
|
|
||||||
|
switch errType {
|
||||||
|
case "authentication_error":
|
||||||
|
return "auth"
|
||||||
|
case "billing_error", "subscription_error":
|
||||||
|
return "billing"
|
||||||
|
case "rate_limit_error":
|
||||||
|
if strings.Contains(msg, "concurrency") || strings.Contains(msg, "pending") || strings.Contains(msg, "queue") {
|
||||||
|
return "concurrency"
|
||||||
|
}
|
||||||
|
return "upstream"
|
||||||
|
case "invalid_request_error":
|
||||||
|
return "response"
|
||||||
|
case "upstream_error", "overloaded_error":
|
||||||
|
return "upstream"
|
||||||
|
case "api_error":
|
||||||
|
if strings.Contains(msg, "no available accounts") {
|
||||||
|
return "scheduling"
|
||||||
|
}
|
||||||
|
return "internal"
|
||||||
|
default:
|
||||||
|
return "internal"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func classifyOpsSeverity(errType string, status int) string {
|
||||||
|
switch errType {
|
||||||
|
case "invalid_request_error", "authentication_error", "billing_error", "subscription_error":
|
||||||
|
return "P3"
|
||||||
|
}
|
||||||
|
if status >= 500 {
|
||||||
|
return "P1"
|
||||||
|
}
|
||||||
|
if status == 429 {
|
||||||
|
return "P1"
|
||||||
|
}
|
||||||
|
if status >= 400 {
|
||||||
|
return "P2"
|
||||||
|
}
|
||||||
|
return "P3"
|
||||||
|
}
|
||||||
|
|
||||||
|
func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
||||||
|
switch errType {
|
||||||
|
case "authentication_error", "invalid_request_error":
|
||||||
|
return false
|
||||||
|
case "timeout_error":
|
||||||
|
return true
|
||||||
|
case "rate_limit_error":
|
||||||
|
// May be transient (upstream or queue); retry can help.
|
||||||
|
return true
|
||||||
|
case "billing_error", "subscription_error":
|
||||||
|
return false
|
||||||
|
case "upstream_error", "overloaded_error":
|
||||||
|
return statusCode >= 500 || statusCode == 429 || statusCode == 529
|
||||||
|
default:
|
||||||
|
return statusCode >= 500
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||||
|
switch strings.TrimSpace(code) {
|
||||||
|
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if phase == "billing" || phase == "concurrency" {
|
||||||
|
// SLA/错误率排除“用户级业务限制”
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// Avoid treating upstream rate limits as business-limited.
|
||||||
|
if errType == "rate_limit_error" && strings.Contains(strings.ToLower(message), "upstream") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_ = status
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// classifyOpsErrorOwner attributes a failure to the responsible party:
// "provider" (upstream vendor), "client" (caller/account), or "sub2api"
// (this gateway).
func classifyOpsErrorOwner(phase string, message string) string {
	if phase == "upstream" || phase == "network" {
		return "provider"
	}
	if phase == "billing" || phase == "concurrency" || phase == "auth" || phase == "response" {
		return "client"
	}
	// Unknown phase: fall back to scanning the message for upstream hints.
	if strings.Contains(strings.ToLower(message), "upstream") {
		return "provider"
	}
	return "sub2api"
}
|
||||||
|
|
||||||
|
// classifyOpsErrorSource maps a request phase to the subsystem where the
// error originated, falling back to message inspection for unknown phases.
func classifyOpsErrorSource(phase string, message string) string {
	switch phase {
	case "upstream":
		return "upstream_http"
	case "network":
		return "upstream_network"
	case "billing", "concurrency":
		// These phases name their own source.
		return phase
	}
	if strings.Contains(strings.ToLower(message), "upstream") {
		return "upstream_http"
	}
	return "internal"
}
|
||||||
|
|
||||||
|
// truncateString returns s shortened to at most max bytes. A non-positive max
// yields the empty string. When the byte cut lands inside a multi-byte UTF-8
// rune, the partial rune is trimmed so the result does not end mid-sequence.
func truncateString(s string, max int) string {
	if max <= 0 {
		return ""
	}
	if len(s) <= max {
		return s
	}
	cut := s[:max]
	// Only the tail can be damaged by the byte cut, and a rune is at most
	// utf8.UTFMax bytes, so trim at most that many trailing bytes. The previous
	// implementation validated the ENTIRE string each iteration, which erased
	// all content whenever s carried a pre-existing invalid byte earlier on.
	for i := 0; i < utf8.UTFMax && len(cut) > 0; i++ {
		r, size := utf8.DecodeLastRuneInString(cut)
		if r != utf8.RuneError || size > 1 {
			// Last rune decodes cleanly (size>1 RuneError is a literally
			// encoded U+FFFD, which is valid) — nothing left to trim.
			break
		}
		cut = cut[:len(cut)-1]
	}
	return cut
}
|
||||||
|
|
||||||
|
// strconvItoa renders an int in decimal. Thin wrapper kept so call sites in
// this file stay uniform.
func strconvItoa(v int) string {
	return strconv.FormatInt(int64(v), 10)
}
|
||||||
@@ -21,6 +21,7 @@ func ProvideAdminHandlers(
|
|||||||
redeemHandler *admin.RedeemHandler,
|
redeemHandler *admin.RedeemHandler,
|
||||||
promoHandler *admin.PromoHandler,
|
promoHandler *admin.PromoHandler,
|
||||||
settingHandler *admin.SettingHandler,
|
settingHandler *admin.SettingHandler,
|
||||||
|
opsHandler *admin.OpsHandler,
|
||||||
systemHandler *admin.SystemHandler,
|
systemHandler *admin.SystemHandler,
|
||||||
subscriptionHandler *admin.SubscriptionHandler,
|
subscriptionHandler *admin.SubscriptionHandler,
|
||||||
usageHandler *admin.UsageHandler,
|
usageHandler *admin.UsageHandler,
|
||||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
|||||||
Redeem: redeemHandler,
|
Redeem: redeemHandler,
|
||||||
Promo: promoHandler,
|
Promo: promoHandler,
|
||||||
Setting: settingHandler,
|
Setting: settingHandler,
|
||||||
|
Ops: opsHandler,
|
||||||
System: systemHandler,
|
System: systemHandler,
|
||||||
Subscription: subscriptionHandler,
|
Subscription: subscriptionHandler,
|
||||||
Usage: usageHandler,
|
Usage: usageHandler,
|
||||||
@@ -109,6 +111,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
admin.NewRedeemHandler,
|
admin.NewRedeemHandler,
|
||||||
admin.NewPromoHandler,
|
admin.NewPromoHandler,
|
||||||
admin.NewSettingHandler,
|
admin.NewSettingHandler,
|
||||||
|
admin.NewOpsHandler,
|
||||||
ProvideSystemHandler,
|
ProvideSystemHandler,
|
||||||
admin.NewSubscriptionHandler,
|
admin.NewSubscriptionHandler,
|
||||||
admin.NewUsageHandler,
|
admin.NewUsageHandler,
|
||||||
|
|||||||
@@ -1,13 +1,63 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// RateLimitFailureMode Redis 故障策略
|
||||||
|
type RateLimitFailureMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
RateLimitFailOpen RateLimitFailureMode = iota
|
||||||
|
RateLimitFailClose
|
||||||
|
)
|
||||||
|
|
||||||
|
// RateLimitOptions 限流可选配置
|
||||||
|
type RateLimitOptions struct {
|
||||||
|
FailureMode RateLimitFailureMode
|
||||||
|
}
|
||||||
|
|
||||||
|
var rateLimitScript = redis.NewScript(`
|
||||||
|
local current = redis.call('INCR', KEYS[1])
|
||||||
|
local ttl = redis.call('PTTL', KEYS[1])
|
||||||
|
local repaired = 0
|
||||||
|
if current == 1 then
|
||||||
|
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
||||||
|
elseif ttl == -1 then
|
||||||
|
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
||||||
|
repaired = 1
|
||||||
|
end
|
||||||
|
return {current, repaired}
|
||||||
|
`)
|
||||||
|
|
||||||
|
// rateLimitRun 允许测试覆写脚本执行逻辑
|
||||||
|
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
|
||||||
|
values, err := rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Slice()
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
if len(values) < 2 {
|
||||||
|
return 0, false, fmt.Errorf("rate limit script returned %d values", len(values))
|
||||||
|
}
|
||||||
|
count, err := parseInt64(values[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
repaired, err := parseInt64(values[1])
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
return count, repaired == 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
// RateLimiter Redis 速率限制器
|
// RateLimiter Redis 速率限制器
|
||||||
type RateLimiter struct {
|
type RateLimiter struct {
|
||||||
redis *redis.Client
|
redis *redis.Client
|
||||||
@@ -27,34 +77,85 @@ func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
|
|||||||
// limit: 时间窗口内最大请求数
|
// limit: 时间窗口内最大请求数
|
||||||
// window: 时间窗口
|
// window: 时间窗口
|
||||||
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
|
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
|
||||||
|
return r.LimitWithOptions(key, limit, window, RateLimitOptions{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// LimitWithOptions 返回速率限制中间件(带可选配置)
|
||||||
|
func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Duration, opts RateLimitOptions) gin.HandlerFunc {
|
||||||
|
failureMode := opts.FailureMode
|
||||||
|
if failureMode != RateLimitFailClose {
|
||||||
|
failureMode = RateLimitFailOpen
|
||||||
|
}
|
||||||
|
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
ip := c.ClientIP()
|
ip := c.ClientIP()
|
||||||
redisKey := r.prefix + key + ":" + ip
|
redisKey := r.prefix + key + ":" + ip
|
||||||
|
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
// 使用 INCR 原子操作增加计数
|
windowMillis := windowTTLMillis(window)
|
||||||
count, err := r.redis.Incr(ctx, redisKey).Result()
|
|
||||||
|
// 使用 Lua 脚本原子操作增加计数并设置过期
|
||||||
|
count, repaired, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("[RateLimit] redis error: key=%s mode=%s err=%v", redisKey, failureModeLabel(failureMode), err)
|
||||||
|
if failureMode == RateLimitFailClose {
|
||||||
|
abortRateLimit(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
// Redis 错误时放行,避免影响正常服务
|
// Redis 错误时放行,避免影响正常服务
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if repaired {
|
||||||
// 首次访问时设置过期时间
|
log.Printf("[RateLimit] ttl repaired: key=%s window_ms=%d", redisKey, windowMillis)
|
||||||
if count == 1 {
|
|
||||||
r.redis.Expire(ctx, redisKey, window)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 超过限制
|
// 超过限制
|
||||||
if count > int64(limit) {
|
if count > int64(limit) {
|
||||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
abortRateLimit(c)
|
||||||
"error": "rate limit exceeded",
|
|
||||||
"message": "Too many requests, please try again later",
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func windowTTLMillis(window time.Duration) int64 {
|
||||||
|
ttl := window.Milliseconds()
|
||||||
|
if ttl < 1 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
|
||||||
|
func abortRateLimit(c *gin.Context) {
|
||||||
|
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||||
|
"error": "rate limit exceeded",
|
||||||
|
"message": "Too many requests, please try again later",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func failureModeLabel(mode RateLimitFailureMode) string {
|
||||||
|
if mode == RateLimitFailClose {
|
||||||
|
return "fail-close"
|
||||||
|
}
|
||||||
|
return "fail-open"
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseInt64 coerces a dynamically typed value (int64, int, or a decimal
// string) into an int64, returning an error for any other type or an
// unparseable string.
func parseInt64(value any) (int64, error) {
	switch v := value.(type) {
	case int64:
		return v, nil
	case int:
		return int64(v), nil
	case string:
		parsed, err := strconv.ParseInt(v, 10, 64)
		if err != nil {
			// Normalize to zero on failure regardless of the parse error kind.
			return 0, err
		}
		return parsed, nil
	}
	return 0, fmt.Errorf("unexpected value type %T", value)
}
||||||
|
|||||||
114
backend/internal/middleware/rate_limiter_integration_test.go
Normal file
114
backend/internal/middleware/rate_limiter_integration_test.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||||
|
)
|
||||||
|
|
||||||
|
const redisImageTag = "redis:8.4-alpine"
|
||||||
|
|
||||||
|
func TestRateLimiterSetsTTLAndDoesNotRefresh(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
rdb := startRedis(t, ctx)
|
||||||
|
limiter := NewRateLimiter(rdb)
|
||||||
|
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(limiter.Limit("ttl-test", 10, 2*time.Second))
|
||||||
|
router.GET("/test", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
|
})
|
||||||
|
|
||||||
|
recorder := performRequest(router)
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
redisKey := limiter.prefix + "ttl-test:127.0.0.1"
|
||||||
|
ttlBefore, err := rdb.PTTL(ctx, redisKey).Result()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Greater(t, ttlBefore, time.Duration(0))
|
||||||
|
require.LessOrEqual(t, ttlBefore, 2*time.Second)
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
recorder = performRequest(router)
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
ttlAfter, err := rdb.PTTL(ctx, redisKey).Result()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Less(t, ttlAfter, ttlBefore)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterFixesMissingTTL(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
rdb := startRedis(t, ctx)
|
||||||
|
limiter := NewRateLimiter(rdb)
|
||||||
|
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(limiter.Limit("ttl-missing", 10, 2*time.Second))
|
||||||
|
router.GET("/test", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
|
})
|
||||||
|
|
||||||
|
redisKey := limiter.prefix + "ttl-missing:127.0.0.1"
|
||||||
|
require.NoError(t, rdb.Set(ctx, redisKey, 5, 0).Err())
|
||||||
|
|
||||||
|
ttlBefore, err := rdb.PTTL(ctx, redisKey).Result()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Less(t, ttlBefore, time.Duration(0))
|
||||||
|
|
||||||
|
recorder := performRequest(router)
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
ttlAfter, err := rdb.PTTL(ctx, redisKey).Result()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Greater(t, ttlAfter, time.Duration(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func performRequest(router *gin.Engine) *httptest.ResponseRecorder {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
req.RemoteAddr = "127.0.0.1:1234"
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
return recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
func startRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
redisContainer, err := tcredis.Run(ctx, redisImageTag)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = redisContainer.Terminate(ctx)
|
||||||
|
})
|
||||||
|
|
||||||
|
redisHost, err := redisContainer.Host(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
|
||||||
|
DB: 0,
|
||||||
|
})
|
||||||
|
require.NoError(t, rdb.Ping(ctx).Err())
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = rdb.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return rdb
|
||||||
|
}
|
||||||
100
backend/internal/middleware/rate_limiter_test.go
Normal file
100
backend/internal/middleware/rate_limiter_test.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWindowTTLMillis(t *testing.T) {
|
||||||
|
require.Equal(t, int64(1), windowTTLMillis(500*time.Microsecond))
|
||||||
|
require.Equal(t, int64(1), windowTTLMillis(1500*time.Microsecond))
|
||||||
|
require.Equal(t, int64(2), windowTTLMillis(2500*time.Microsecond))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterFailureModes(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: "127.0.0.1:1",
|
||||||
|
DialTimeout: 50 * time.Millisecond,
|
||||||
|
ReadTimeout: 50 * time.Millisecond,
|
||||||
|
WriteTimeout: 50 * time.Millisecond,
|
||||||
|
})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = rdb.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
limiter := NewRateLimiter(rdb)
|
||||||
|
|
||||||
|
failOpenRouter := gin.New()
|
||||||
|
failOpenRouter.Use(limiter.Limit("test", 1, time.Second))
|
||||||
|
failOpenRouter.GET("/test", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
req.RemoteAddr = "127.0.0.1:1234"
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
failOpenRouter.ServeHTTP(recorder, req)
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
failCloseRouter := gin.New()
|
||||||
|
failCloseRouter.Use(limiter.LimitWithOptions("test", 1, time.Second, RateLimitOptions{
|
||||||
|
FailureMode: RateLimitFailClose,
|
||||||
|
}))
|
||||||
|
failCloseRouter.GET("/test", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
|
})
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
req.RemoteAddr = "127.0.0.1:1234"
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
failCloseRouter.ServeHTTP(recorder, req)
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterSuccessAndLimit(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
originalRun := rateLimitRun
|
||||||
|
counts := []int64{1, 2}
|
||||||
|
callIndex := 0
|
||||||
|
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
|
||||||
|
if callIndex >= len(counts) {
|
||||||
|
return counts[len(counts)-1], false, nil
|
||||||
|
}
|
||||||
|
value := counts[callIndex]
|
||||||
|
callIndex++
|
||||||
|
return value, false, nil
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
rateLimitRun = originalRun
|
||||||
|
})
|
||||||
|
|
||||||
|
limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}))
|
||||||
|
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(limiter.Limit("test", 1, time.Second))
|
||||||
|
router.GET("/test", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
req.RemoteAddr = "127.0.0.1:1234"
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
req.RemoteAddr = "127.0.0.1:1234"
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
|
||||||
|
}
|
||||||
@@ -7,7 +7,14 @@ type Key string
|
|||||||
const (
|
const (
|
||||||
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
|
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
|
||||||
ForcePlatform Key = "ctx_force_platform"
|
ForcePlatform Key = "ctx_force_platform"
|
||||||
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
|
|
||||||
|
// ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。
|
||||||
|
ClientRequestID Key = "ctx_client_request_id"
|
||||||
|
|
||||||
|
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
|
||||||
|
RetryCount Key = "ctx_retry_count"
|
||||||
|
|
||||||
|
// IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端
|
||||||
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
|
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
|
||||||
// Group 认证后的分组信息,由 API Key 认证中间件设置
|
// Group 认证后的分组信息,由 API Key 认证中间件设置
|
||||||
Group Key = "ctx_group"
|
Group Key = "ctx_group"
|
||||||
|
|||||||
@@ -9,6 +9,12 @@ type DashboardStats struct {
|
|||||||
TotalUsers int64 `json:"total_users"`
|
TotalUsers int64 `json:"total_users"`
|
||||||
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
|
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
|
||||||
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
||||||
|
// 小时活跃用户数(UTC 当前小时)
|
||||||
|
HourlyActiveUsers int64 `json:"hourly_active_users"`
|
||||||
|
|
||||||
|
// 预聚合新鲜度
|
||||||
|
StatsUpdatedAt string `json:"stats_updated_at"`
|
||||||
|
StatsStale bool `json:"stats_stale"`
|
||||||
|
|
||||||
// API Key 统计
|
// API Key 统计
|
||||||
TotalAPIKeys int64 `json:"total_api_keys"`
|
TotalAPIKeys int64 `json:"total_api_keys"`
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -115,6 +116,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
|||||||
account.ID = created.ID
|
account.ID = created.ID
|
||||||
account.CreatedAt = created.CreatedAt
|
account.CreatedAt = created.CreatedAt
|
||||||
account.UpdatedAt = created.UpdatedAt
|
account.UpdatedAt = created.UpdatedAt
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -341,10 +345,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
|||||||
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||||||
}
|
}
|
||||||
account.UpdatedAt = updated.UpdatedAt
|
account.UpdatedAt = updated.UpdatedAt
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||||
|
groupIDs, err := r.loadAccountGroupIDs(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
// 使用事务保证账号与关联分组的删除原子性
|
// 使用事务保证账号与关联分组的删除原子性
|
||||||
tx, err := r.client.Tx(ctx)
|
tx, err := r.client.Tx(ctx)
|
||||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||||
@@ -368,7 +379,12 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if tx != nil {
|
if tx != nil {
|
||||||
return tx.Commit()
|
if err := tx.Commit(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -455,7 +471,18 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
|
|||||||
Where(dbaccount.IDEQ(id)).
|
Where(dbaccount.IDEQ(id)).
|
||||||
SetLastUsedAt(now).
|
SetLastUsedAt(now).
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payload := map[string]any{
|
||||||
|
"last_used": map[string]int64{
|
||||||
|
strconv.FormatInt(id, 10): now.Unix(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||||
@@ -479,7 +506,18 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
|
|||||||
args = append(args, pq.Array(ids))
|
args = append(args, pq.Array(ids))
|
||||||
|
|
||||||
_, err := r.sql.ExecContext(ctx, caseSQL, args...)
|
_, err := r.sql.ExecContext(ctx, caseSQL, args...)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
lastUsedPayload := make(map[string]int64, len(updates))
|
||||||
|
for id, ts := range updates {
|
||||||
|
lastUsedPayload[strconv.FormatInt(id, 10)] = ts.Unix()
|
||||||
|
}
|
||||||
|
payload := map[string]any{"last_used": lastUsedPayload}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
@@ -488,7 +526,13 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
|||||||
SetStatus(service.StatusError).
|
SetStatus(service.StatusError).
|
||||||
SetErrorMessage(errorMsg).
|
SetErrorMessage(errorMsg).
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||||
@@ -497,7 +541,14 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
|
|||||||
SetGroupID(groupID).
|
SetGroupID(groupID).
|
||||||
SetPriority(priority).
|
SetPriority(priority).
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
|
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
|
||||||
@@ -507,7 +558,14 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
|
|||||||
dbaccountgroup.GroupIDEQ(groupID),
|
dbaccountgroup.GroupIDEQ(groupID),
|
||||||
).
|
).
|
||||||
Exec(ctx)
|
Exec(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
|
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
|
||||||
@@ -528,6 +586,10 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||||
|
existingGroupIDs, err := r.loadAccountGroupIDs(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
// 使用事务保证删除旧绑定与创建新绑定的原子性
|
// 使用事务保证删除旧绑定与创建新绑定的原子性
|
||||||
tx, err := r.client.Tx(ctx)
|
tx, err := r.client.Tx(ctx)
|
||||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||||
@@ -568,7 +630,13 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
|
|||||||
}
|
}
|
||||||
|
|
||||||
if tx != nil {
|
if tx != nil {
|
||||||
return tx.Commit()
|
if err := tx.Commit(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -672,7 +740,13 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
|||||||
SetRateLimitedAt(now).
|
SetRateLimitedAt(now).
|
||||||
SetRateLimitResetAt(resetAt).
|
SetRateLimitResetAt(resetAt).
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||||
@@ -706,6 +780,9 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
|
|||||||
if affected == 0 {
|
if affected == 0 {
|
||||||
return service.ErrAccountNotFound
|
return service.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -714,7 +791,13 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
|||||||
Where(dbaccount.IDEQ(id)).
|
Where(dbaccount.IDEQ(id)).
|
||||||
SetOverloadUntil(until).
|
SetOverloadUntil(until).
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
@@ -727,7 +810,13 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
|
|||||||
AND deleted_at IS NULL
|
AND deleted_at IS NULL
|
||||||
AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
|
AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
|
||||||
`, until, reason, id)
|
`, until, reason, id)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||||
@@ -739,7 +828,13 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
|
|||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
AND deleted_at IS NULL
|
AND deleted_at IS NULL
|
||||||
`, id)
|
`, id)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||||
@@ -749,7 +844,13 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
|||||||
ClearRateLimitResetAt().
|
ClearRateLimitResetAt().
|
||||||
ClearOverloadUntil().
|
ClearOverloadUntil().
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||||
@@ -770,6 +871,9 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
|
|||||||
if affected == 0 {
|
if affected == 0 {
|
||||||
return service.ErrAccountNotFound
|
return service.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -792,7 +896,13 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
|||||||
Where(dbaccount.IDEQ(id)).
|
Where(dbaccount.IDEQ(id)).
|
||||||
SetSchedulable(schedulable).
|
SetSchedulable(schedulable).
|
||||||
Save(ctx)
|
Save(ctx)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||||
@@ -813,6 +923,11 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
if rows > 0 {
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
return rows, nil
|
return rows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -844,6 +959,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
|||||||
if affected == 0 {
|
if affected == 0 {
|
||||||
return service.ErrAccountNotFound
|
return service.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -928,6 +1046,12 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
if rows > 0 {
|
||||||
|
payload := map[string]any{"account_ids": ids}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
return rows, nil
|
return rows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1170,6 +1294,54 @@ func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs []
|
|||||||
return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
|
return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *accountRepository) loadAccountGroupIDs(ctx context.Context, accountID int64) ([]int64, error) {
|
||||||
|
entries, err := r.client.AccountGroup.
|
||||||
|
Query().
|
||||||
|
Where(dbaccountgroup.AccountIDEQ(accountID)).
|
||||||
|
All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ids := make([]int64, 0, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
ids = append(ids, entry.GroupID)
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeGroupIDs(a []int64, b []int64) []int64 {
|
||||||
|
seen := make(map[int64]struct{}, len(a)+len(b))
|
||||||
|
out := make([]int64, 0, len(a)+len(b))
|
||||||
|
for _, id := range a {
|
||||||
|
if id <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
for _, id := range b {
|
||||||
|
if id <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildSchedulerGroupPayload(groupIDs []int64) map[string]any {
|
||||||
|
if len(groupIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return map[string]any{"group_ids": groupIDs}
|
||||||
|
}
|
||||||
|
|
||||||
func accountEntityToService(m *dbent.Account) *service.Account {
|
func accountEntityToService(m *dbent.Account) *service.Account {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||||
apiKeyRateLimitDuration = 24 * time.Hour
|
apiKeyRateLimitDuration = 24 * time.Hour
|
||||||
|
apiKeyAuthCachePrefix = "apikey:auth:"
|
||||||
)
|
)
|
||||||
|
|
||||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||||
@@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string {
|
|||||||
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func apiKeyAuthCacheKey(key string) string {
|
||||||
|
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
|
||||||
|
}
|
||||||
|
|
||||||
type apiKeyCache struct {
|
type apiKeyCache struct {
|
||||||
rdb *redis.Client
|
rdb *redis.Client
|
||||||
}
|
}
|
||||||
@@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er
|
|||||||
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||||
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
|
||||||
|
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var entry service.APIKeyAuthCacheEntry
|
||||||
|
if err := json.Unmarshal(val, &entry); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &entry, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||||
|
if entry == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(entry)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||||
|
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
@@ -64,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
|
|||||||
return apiKeyEntityToService(m), nil
|
return apiKeyEntityToService(m), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
|
// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者(用户)ID。
|
||||||
// 相比 GetByID,此方法性能更优,因为:
|
// 相比 GetByID,此方法性能更优,因为:
|
||||||
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
|
// - 使用 Select() 只查询必要字段,减少数据传输量
|
||||||
// - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
|
// - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
|
||||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
// - 适用于删除等只需 key 与用户 ID 的场景
|
||||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
m, err := r.activeQuery().
|
m, err := r.activeQuery().
|
||||||
Where(apikey.IDEQ(id)).
|
Where(apikey.IDEQ(id)).
|
||||||
Select(apikey.FieldUserID).
|
Select(apikey.FieldKey, apikey.FieldUserID).
|
||||||
Only(ctx)
|
Only(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if dbent.IsNotFound(err) {
|
if dbent.IsNotFound(err) {
|
||||||
return 0, service.ErrAPIKeyNotFound
|
return "", 0, service.ErrAPIKeyNotFound
|
||||||
}
|
}
|
||||||
return 0, err
|
return "", 0, err
|
||||||
}
|
}
|
||||||
return m.UserID, nil
|
return m.Key, m.UserID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
@@ -98,6 +100,54 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
|
|||||||
return apiKeyEntityToService(m), nil
|
return apiKeyEntityToService(m), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
m, err := r.activeQuery().
|
||||||
|
Where(apikey.KeyEQ(key)).
|
||||||
|
Select(
|
||||||
|
apikey.FieldID,
|
||||||
|
apikey.FieldUserID,
|
||||||
|
apikey.FieldGroupID,
|
||||||
|
apikey.FieldStatus,
|
||||||
|
apikey.FieldIPWhitelist,
|
||||||
|
apikey.FieldIPBlacklist,
|
||||||
|
).
|
||||||
|
WithUser(func(q *dbent.UserQuery) {
|
||||||
|
q.Select(
|
||||||
|
user.FieldID,
|
||||||
|
user.FieldStatus,
|
||||||
|
user.FieldRole,
|
||||||
|
user.FieldBalance,
|
||||||
|
user.FieldConcurrency,
|
||||||
|
)
|
||||||
|
}).
|
||||||
|
WithGroup(func(q *dbent.GroupQuery) {
|
||||||
|
q.Select(
|
||||||
|
group.FieldID,
|
||||||
|
group.FieldName,
|
||||||
|
group.FieldPlatform,
|
||||||
|
group.FieldStatus,
|
||||||
|
group.FieldSubscriptionType,
|
||||||
|
group.FieldRateMultiplier,
|
||||||
|
group.FieldDailyLimitUsd,
|
||||||
|
group.FieldWeeklyLimitUsd,
|
||||||
|
group.FieldMonthlyLimitUsd,
|
||||||
|
group.FieldImagePrice1k,
|
||||||
|
group.FieldImagePrice2k,
|
||||||
|
group.FieldImagePrice4k,
|
||||||
|
group.FieldClaudeCodeOnly,
|
||||||
|
group.FieldFallbackGroupID,
|
||||||
|
)
|
||||||
|
}).
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dbent.IsNotFound(err) {
|
||||||
|
return nil, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return apiKeyEntityToService(m), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
|
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
|
||||||
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
||||||
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
||||||
@@ -283,6 +333,28 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
|
|||||||
return int64(count), err
|
return int64(count), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
keys, err := r.activeQuery().
|
||||||
|
Where(apikey.UserIDEQ(userID)).
|
||||||
|
Select(apikey.FieldKey).
|
||||||
|
Strings(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
keys, err := r.activeQuery().
|
||||||
|
Where(apikey.GroupIDEQ(groupID)).
|
||||||
|
Select(apikey.FieldKey).
|
||||||
|
Strings(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ var (
|
|||||||
return redis.call('ZCARD', key)
|
return redis.call('ZCARD', key)
|
||||||
`)
|
`)
|
||||||
|
|
||||||
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
|
// incrementWaitScript - refreshes TTL on each increment to keep queue depth accurate
|
||||||
// KEYS[1] = wait queue key
|
// KEYS[1] = wait queue key
|
||||||
// ARGV[1] = maxWait
|
// ARGV[1] = maxWait
|
||||||
// ARGV[2] = TTL in seconds
|
// ARGV[2] = TTL in seconds
|
||||||
@@ -111,15 +111,13 @@ var (
|
|||||||
|
|
||||||
local newVal = redis.call('INCR', KEYS[1])
|
local newVal = redis.call('INCR', KEYS[1])
|
||||||
|
|
||||||
-- Only set TTL on first creation to avoid refreshing zombie data
|
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
|
||||||
if newVal == 1 then
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
|
||||||
end
|
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
`)
|
`)
|
||||||
|
|
||||||
// incrementAccountWaitScript - account-level wait queue count
|
// incrementAccountWaitScript - account-level wait queue count (refresh TTL on each increment)
|
||||||
incrementAccountWaitScript = redis.NewScript(`
|
incrementAccountWaitScript = redis.NewScript(`
|
||||||
local current = redis.call('GET', KEYS[1])
|
local current = redis.call('GET', KEYS[1])
|
||||||
if current == false then
|
if current == false then
|
||||||
@@ -134,10 +132,8 @@ var (
|
|||||||
|
|
||||||
local newVal = redis.call('INCR', KEYS[1])
|
local newVal = redis.call('INCR', KEYS[1])
|
||||||
|
|
||||||
-- Only set TTL on first creation to avoid refreshing zombie data
|
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
|
||||||
if newVal == 1 then
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
|
||||||
end
|
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
`)
|
`)
|
||||||
|
|||||||
387
backend/internal/repository/dashboard_aggregation_repo.go
Normal file
387
backend/internal/repository/dashboard_aggregation_repo.go
Normal file
@@ -0,0 +1,387 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dashboardAggregationRepository struct {
|
||||||
|
sql sqlExecutor
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
||||||
|
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
||||||
|
if sqlDB == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !isPostgresDriver(sqlDB) {
|
||||||
|
log.Printf("[DashboardAggregation] 检测到非 PostgreSQL 驱动,已自动禁用预聚合")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return newDashboardAggregationRepositoryWithSQL(sqlDB)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDashboardAggregationRepositoryWithSQL(sqlq sqlExecutor) *dashboardAggregationRepository {
|
||||||
|
return &dashboardAggregationRepository{sql: sqlq}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isPostgresDriver(db *sql.DB) bool {
|
||||||
|
if db == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, ok := db.Driver().(*pq.Driver)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
startUTC := start.UTC()
|
||||||
|
endUTC := end.UTC()
|
||||||
|
if !endUTC.After(startUTC) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hourStart := startUTC.Truncate(time.Hour)
|
||||||
|
hourEnd := endUTC.Truncate(time.Hour)
|
||||||
|
if endUTC.After(hourEnd) {
|
||||||
|
hourEnd = hourEnd.Add(time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
dayStart := truncateToDayUTC(startUTC)
|
||||||
|
dayEnd := truncateToDayUTC(endUTC)
|
||||||
|
if endUTC.After(dayEnd) {
|
||||||
|
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
||||||
|
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||||
|
var ts time.Time
|
||||||
|
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"
|
||||||
|
if err := scanSingleRow(ctx, r.sql, query, nil, &ts); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return time.Unix(0, 0).UTC(), nil
|
||||||
|
}
|
||||||
|
return time.Time{}, err
|
||||||
|
}
|
||||||
|
return ts.UTC(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||||
|
query := `
|
||||||
|
INSERT INTO usage_dashboard_aggregation_watermark (id, last_aggregated_at, updated_at)
|
||||||
|
VALUES (1, $1, NOW())
|
||||||
|
ON CONFLICT (id)
|
||||||
|
DO UPDATE SET last_aggregated_at = EXCLUDED.last_aggregated_at, updated_at = EXCLUDED.updated_at
|
||||||
|
`
|
||||||
|
_, err := r.sql.ExecContext(ctx, query, aggregatedAt.UTC())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||||
|
hourlyCutoffUTC := hourlyCutoff.UTC()
|
||||||
|
dailyCutoffUTC := dailyCutoff.UTC()
|
||||||
|
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||||
|
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if isPartitioned {
|
||||||
|
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||||
|
}
|
||||||
|
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||||
|
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
|
||||||
|
if err != nil || !isPartitioned {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
monthStart := truncateToMonthUTC(now)
|
||||||
|
prevMonth := monthStart.AddDate(0, -1, 0)
|
||||||
|
nextMonth := monthStart.AddDate(0, 1, 0)
|
||||||
|
|
||||||
|
for _, m := range []time.Time{prevMonth, monthStart, nextMonth} {
|
||||||
|
if err := r.createUsageLogsPartition(ctx, m); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
|
||||||
|
query := `
|
||||||
|
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
|
||||||
|
SELECT DISTINCT
|
||||||
|
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
||||||
|
user_id
|
||||||
|
FROM usage_logs
|
||||||
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
|
`
|
||||||
|
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
|
||||||
|
query := `
|
||||||
|
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
|
||||||
|
SELECT DISTINCT
|
||||||
|
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
|
||||||
|
user_id
|
||||||
|
FROM usage_dashboard_hourly_users
|
||||||
|
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
|
`
|
||||||
|
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
|
||||||
|
query := `
|
||||||
|
WITH hourly AS (
|
||||||
|
SELECT
|
||||||
|
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
||||||
|
COUNT(*) AS total_requests,
|
||||||
|
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
||||||
|
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
||||||
|
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
|
||||||
|
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
||||||
|
COALESCE(SUM(total_cost), 0) AS total_cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
||||||
|
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
|
||||||
|
FROM usage_logs
|
||||||
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
|
GROUP BY 1
|
||||||
|
),
|
||||||
|
user_counts AS (
|
||||||
|
SELECT bucket_start, COUNT(*) AS active_users
|
||||||
|
FROM usage_dashboard_hourly_users
|
||||||
|
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||||
|
GROUP BY bucket_start
|
||||||
|
)
|
||||||
|
INSERT INTO usage_dashboard_hourly (
|
||||||
|
bucket_start,
|
||||||
|
total_requests,
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
cache_creation_tokens,
|
||||||
|
cache_read_tokens,
|
||||||
|
total_cost,
|
||||||
|
actual_cost,
|
||||||
|
total_duration_ms,
|
||||||
|
active_users,
|
||||||
|
computed_at
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
hourly.bucket_start,
|
||||||
|
hourly.total_requests,
|
||||||
|
hourly.input_tokens,
|
||||||
|
hourly.output_tokens,
|
||||||
|
hourly.cache_creation_tokens,
|
||||||
|
hourly.cache_read_tokens,
|
||||||
|
hourly.total_cost,
|
||||||
|
hourly.actual_cost,
|
||||||
|
hourly.total_duration_ms,
|
||||||
|
COALESCE(user_counts.active_users, 0) AS active_users,
|
||||||
|
NOW()
|
||||||
|
FROM hourly
|
||||||
|
LEFT JOIN user_counts ON user_counts.bucket_start = hourly.bucket_start
|
||||||
|
ON CONFLICT (bucket_start)
|
||||||
|
DO UPDATE SET
|
||||||
|
total_requests = EXCLUDED.total_requests,
|
||||||
|
input_tokens = EXCLUDED.input_tokens,
|
||||||
|
output_tokens = EXCLUDED.output_tokens,
|
||||||
|
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
|
||||||
|
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
||||||
|
total_cost = EXCLUDED.total_cost,
|
||||||
|
actual_cost = EXCLUDED.actual_cost,
|
||||||
|
total_duration_ms = EXCLUDED.total_duration_ms,
|
||||||
|
active_users = EXCLUDED.active_users,
|
||||||
|
computed_at = EXCLUDED.computed_at
|
||||||
|
`
|
||||||
|
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
|
||||||
|
query := `
|
||||||
|
WITH daily AS (
|
||||||
|
SELECT
|
||||||
|
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
|
||||||
|
COALESCE(SUM(total_requests), 0) AS total_requests,
|
||||||
|
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
||||||
|
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
||||||
|
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
|
||||||
|
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
||||||
|
COALESCE(SUM(total_cost), 0) AS total_cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
||||||
|
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
|
||||||
|
FROM usage_dashboard_hourly
|
||||||
|
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||||
|
GROUP BY (bucket_start AT TIME ZONE 'UTC')::date
|
||||||
|
),
|
||||||
|
user_counts AS (
|
||||||
|
SELECT bucket_date, COUNT(*) AS active_users
|
||||||
|
FROM usage_dashboard_daily_users
|
||||||
|
WHERE bucket_date >= $3::date AND bucket_date < $4::date
|
||||||
|
GROUP BY bucket_date
|
||||||
|
)
|
||||||
|
INSERT INTO usage_dashboard_daily (
|
||||||
|
bucket_date,
|
||||||
|
total_requests,
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
cache_creation_tokens,
|
||||||
|
cache_read_tokens,
|
||||||
|
total_cost,
|
||||||
|
actual_cost,
|
||||||
|
total_duration_ms,
|
||||||
|
active_users,
|
||||||
|
computed_at
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
daily.bucket_date,
|
||||||
|
daily.total_requests,
|
||||||
|
daily.input_tokens,
|
||||||
|
daily.output_tokens,
|
||||||
|
daily.cache_creation_tokens,
|
||||||
|
daily.cache_read_tokens,
|
||||||
|
daily.total_cost,
|
||||||
|
daily.actual_cost,
|
||||||
|
daily.total_duration_ms,
|
||||||
|
COALESCE(user_counts.active_users, 0) AS active_users,
|
||||||
|
NOW()
|
||||||
|
FROM daily
|
||||||
|
LEFT JOIN user_counts ON user_counts.bucket_date = daily.bucket_date
|
||||||
|
ON CONFLICT (bucket_date)
|
||||||
|
DO UPDATE SET
|
||||||
|
total_requests = EXCLUDED.total_requests,
|
||||||
|
input_tokens = EXCLUDED.input_tokens,
|
||||||
|
output_tokens = EXCLUDED.output_tokens,
|
||||||
|
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
|
||||||
|
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
||||||
|
total_cost = EXCLUDED.total_cost,
|
||||||
|
actual_cost = EXCLUDED.actual_cost,
|
||||||
|
total_duration_ms = EXCLUDED.total_duration_ms,
|
||||||
|
active_users = EXCLUDED.active_users,
|
||||||
|
computed_at = EXCLUDED.computed_at
|
||||||
|
`
|
||||||
|
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC(), start.UTC(), end.UTC())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) isUsageLogsPartitioned(ctx context.Context) (bool, error) {
|
||||||
|
query := `
|
||||||
|
SELECT EXISTS(
|
||||||
|
SELECT 1
|
||||||
|
FROM pg_partitioned_table pt
|
||||||
|
JOIN pg_class c ON c.oid = pt.partrelid
|
||||||
|
WHERE c.relname = 'usage_logs'
|
||||||
|
)
|
||||||
|
`
|
||||||
|
var partitioned bool
|
||||||
|
if err := scanSingleRow(ctx, r.sql, query, nil, &partitioned); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return partitioned, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) dropUsageLogsPartitions(ctx context.Context, cutoff time.Time) error {
|
||||||
|
rows, err := r.sql.QueryContext(ctx, `
|
||||||
|
SELECT c.relname
|
||||||
|
FROM pg_inherits
|
||||||
|
JOIN pg_class c ON c.oid = pg_inherits.inhrelid
|
||||||
|
JOIN pg_class p ON p.oid = pg_inherits.inhparent
|
||||||
|
WHERE p.relname = 'usage_logs'
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = rows.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
cutoffMonth := truncateToMonthUTC(cutoff)
|
||||||
|
for rows.Next() {
|
||||||
|
var name string
|
||||||
|
if err := rows.Scan(&name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(name, "usage_logs_") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
suffix := strings.TrimPrefix(name, "usage_logs_")
|
||||||
|
month, err := time.Parse("200601", suffix)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
month = month.UTC()
|
||||||
|
if month.Before(cutoffMonth) {
|
||||||
|
if _, err := r.sql.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(name))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Context, month time.Time) error {
|
||||||
|
monthStart := truncateToMonthUTC(month)
|
||||||
|
nextMonth := monthStart.AddDate(0, 1, 0)
|
||||||
|
name := fmt.Sprintf("usage_logs_%s", monthStart.Format("200601"))
|
||||||
|
query := fmt.Sprintf(
|
||||||
|
"CREATE TABLE IF NOT EXISTS %s PARTITION OF usage_logs FOR VALUES FROM (%s) TO (%s)",
|
||||||
|
pq.QuoteIdentifier(name),
|
||||||
|
pq.QuoteLiteral(monthStart.Format("2006-01-02")),
|
||||||
|
pq.QuoteLiteral(nextMonth.Format("2006-01-02")),
|
||||||
|
)
|
||||||
|
_, err := r.sql.ExecContext(ctx, query)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateToDayUTC(t time.Time) time.Time {
|
||||||
|
t = t.UTC()
|
||||||
|
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateToMonthUTC(t time.Time) time.Time {
|
||||||
|
t = t.UTC()
|
||||||
|
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
58
backend/internal/repository/dashboard_cache.go
Normal file
58
backend/internal/repository/dashboard_cache.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
	"context"
	"errors"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/redis/go-redis/v9"
)
|
||||||
|
|
||||||
|
// dashboardStatsCacheKey is the base Redis key under which the serialized
// dashboard stats payload is cached (versioned suffix allows format changes).
const dashboardStatsCacheKey = "dashboard:stats:v1"

// dashboardCache is a Redis-backed implementation of
// service.DashboardStatsCache. It stores a single serialized stats payload
// under keyPrefix + dashboardStatsCacheKey.
type dashboardCache struct {
	rdb       *redis.Client // shared Redis client; not owned or closed by this type
	keyPrefix string        // optional namespace; normalized to end with ":" when non-empty
}
|
||||||
|
|
||||||
|
func NewDashboardCache(rdb *redis.Client, cfg *config.Config) service.DashboardStatsCache {
|
||||||
|
prefix := "sub2api:"
|
||||||
|
if cfg != nil {
|
||||||
|
prefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
|
||||||
|
}
|
||||||
|
if prefix != "" && !strings.HasSuffix(prefix, ":") {
|
||||||
|
prefix += ":"
|
||||||
|
}
|
||||||
|
return &dashboardCache{
|
||||||
|
rdb: rdb,
|
||||||
|
keyPrefix: prefix,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCache) GetDashboardStats(ctx context.Context) (string, error) {
|
||||||
|
val, err := c.rdb.Get(ctx, c.buildKey()).Result()
|
||||||
|
if err != nil {
|
||||||
|
if err == redis.Nil {
|
||||||
|
return "", service.ErrDashboardStatsCacheMiss
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCache) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
|
||||||
|
return c.rdb.Set(ctx, c.buildKey(), data, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCache) buildKey() string {
|
||||||
|
if c.keyPrefix == "" {
|
||||||
|
return dashboardStatsCacheKey
|
||||||
|
}
|
||||||
|
return c.keyPrefix + dashboardStatsCacheKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCache) DeleteDashboardStats(ctx context.Context) error {
|
||||||
|
return c.rdb.Del(ctx, c.buildKey()).Err()
|
||||||
|
}
|
||||||
28
backend/internal/repository/dashboard_cache_test.go
Normal file
28
backend/internal/repository/dashboard_cache_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewDashboardCacheKeyPrefix(t *testing.T) {
|
||||||
|
cache := NewDashboardCache(nil, &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{
|
||||||
|
KeyPrefix: "prod",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
impl, ok := cache.(*dashboardCache)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "prod:", impl.keyPrefix)
|
||||||
|
|
||||||
|
cache = NewDashboardCache(nil, &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{
|
||||||
|
KeyPrefix: "staging:",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
impl, ok = cache.(*dashboardCache)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "staging:", impl.keyPrefix)
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
@@ -55,6 +56,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
|||||||
groupIn.ID = created.ID
|
groupIn.ID = created.ID
|
||||||
groupIn.CreatedAt = created.CreatedAt
|
groupIn.CreatedAt = created.CreatedAt
|
||||||
groupIn.UpdatedAt = created.UpdatedAt
|
groupIn.UpdatedAt = created.UpdatedAt
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return translatePersistenceError(err, nil, service.ErrGroupExists)
|
return translatePersistenceError(err, nil, service.ErrGroupExists)
|
||||||
}
|
}
|
||||||
@@ -111,12 +115,21 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
|||||||
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
||||||
}
|
}
|
||||||
groupIn.UpdatedAt = updated.UpdatedAt
|
groupIn.UpdatedAt = updated.UpdatedAt
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||||
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
|
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
|
||||||
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
if err != nil {
|
||||||
|
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||||
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||||
@@ -246,6 +259,9 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
affected, _ := res.RowsAffected()
|
affected, _ := res.RowsAffected()
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
|
||||||
|
}
|
||||||
return affected, nil
|
return affected, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -353,6 +369,9 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||||
|
log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
return affectedUserIDs, nil
|
return affectedUserIDs, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,23 @@ CREATE TABLE IF NOT EXISTS schema_migrations (
|
|||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
|
// atlasSchemaRevisionsTableDDL creates the revision-tracking table consumed
// during the legacy-to-Atlas baseline alignment (see ensureAtlasBaselineAligned
// in this file). The column set appears to mirror Atlas' own revision table —
// NOTE(review): confirm against the Atlas version in use before changing it.
const atlasSchemaRevisionsTableDDL = `
CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
	version TEXT PRIMARY KEY,
	description TEXT NOT NULL,
	type INTEGER NOT NULL,
	applied INTEGER NOT NULL DEFAULT 0,
	total INTEGER NOT NULL DEFAULT 0,
	executed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
	execution_time BIGINT NOT NULL DEFAULT 0,
	error TEXT NULL,
	error_stmt TEXT NULL,
	hash TEXT NOT NULL DEFAULT '',
	partial_hashes TEXT[] NULL,
	operator_version TEXT NULL
);
`
|
||||||
|
|
||||||
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
|
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
|
||||||
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
|
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
|
||||||
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
|
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
|
||||||
@@ -94,6 +111,11 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
|||||||
return fmt.Errorf("create schema_migrations: %w", err)
|
return fmt.Errorf("create schema_migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 自动对齐 Atlas 基线(如果检测到 legacy schema_migrations 且缺失 atlas_schema_revisions)。
|
||||||
|
if err := ensureAtlasBaselineAligned(ctx, db, fsys); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// 获取所有 .sql 迁移文件并按文件名排序。
|
// 获取所有 .sql 迁移文件并按文件名排序。
|
||||||
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql)。
|
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql)。
|
||||||
files, err := fs.Glob(fsys, "*.sql")
|
files, err := fs.Glob(fsys, "*.sql")
|
||||||
@@ -172,6 +194,80 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||||
|
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check schema_migrations: %w", err)
|
||||||
|
}
|
||||||
|
if !hasLegacy {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hasAtlas, err := tableExists(ctx, db, "atlas_schema_revisions")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check atlas_schema_revisions: %w", err)
|
||||||
|
}
|
||||||
|
if !hasAtlas {
|
||||||
|
if _, err := db.ExecContext(ctx, atlasSchemaRevisionsTableDDL); err != nil {
|
||||||
|
return fmt.Errorf("create atlas_schema_revisions: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int
|
||||||
|
if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM atlas_schema_revisions").Scan(&count); err != nil {
|
||||||
|
return fmt.Errorf("count atlas_schema_revisions: %w", err)
|
||||||
|
}
|
||||||
|
if count > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
version, description, hash, err := latestMigrationBaseline(fsys)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("atlas baseline version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.ExecContext(ctx, `
|
||||||
|
INSERT INTO atlas_schema_revisions (version, description, type, applied, total, executed_at, execution_time, hash)
|
||||||
|
VALUES ($1, $2, $3, 0, 0, NOW(), 0, $4)
|
||||||
|
`, version, description, 1, hash); err != nil {
|
||||||
|
return fmt.Errorf("insert atlas baseline: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tableExists reports whether tableName exists in the public schema of the
// connected PostgreSQL database.
func tableExists(ctx context.Context, db *sql.DB, tableName string) (bool, error) {
	const q = `
	SELECT EXISTS (
		SELECT 1
		FROM information_schema.tables
		WHERE table_schema = 'public' AND table_name = $1
	)
	`
	var found bool
	err := db.QueryRowContext(ctx, q, tableName).Scan(&found)
	return found, err
}
|
||||||
|
|
||||||
|
// latestMigrationBaseline derives the Atlas baseline (version, description,
// content hash) from the lexicographically last *.sql migration in fsys.
// Version and description are both the file name without its ".sql" suffix;
// the hash is the hex-encoded SHA-256 of the whitespace-trimmed file content.
// An fsys without any *.sql file yields the synthetic ("baseline", "baseline", "").
func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
	matches, err := fs.Glob(fsys, "*.sql")
	if err != nil {
		return "", "", "", err
	}
	if len(matches) == 0 {
		return "baseline", "baseline", "", nil
	}

	sort.Strings(matches)
	latest := matches[len(matches)-1]

	raw, err := fs.ReadFile(fsys, latest)
	if err != nil {
		return "", "", "", err
	}

	trimmed := strings.TrimSpace(string(raw))
	digest := sha256.Sum256([]byte(trimmed))

	version := strings.TrimSuffix(latest, ".sql")
	return version, version, hex.EncodeToString(digest[:]), nil
}
|
||||||
|
|
||||||
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
|
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
|
||||||
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
|
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
|
||||||
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
|
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
|
||||||
|
|||||||
709
backend/internal/repository/ops_repo.go
Normal file
709
backend/internal/repository/ops_repo.go
Normal file
@@ -0,0 +1,709 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
// opsRepository implements service.OpsRepository with hand-written SQL over
// a raw *sql.DB. Every method guards against a nil receiver/handle, so the
// zero value fails fast instead of panicking.
type opsRepository struct {
	db *sql.DB // shared connection pool; may be nil, checked by each method
}
|
||||||
|
|
||||||
|
func NewOpsRepository(db *sql.DB) service.OpsRepository {
|
||||||
|
return &opsRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertErrorLog persists a single proxy error event into ops_error_logs and
// returns the id of the inserted row.
//
// Optional fields (pointer ids, optional strings/ints) are mapped to SQL NULL
// via the opsNull* helpers so that absent values are stored as NULL rather
// than zero/"". The 35 bind arguments must stay in the exact order of the
// column list — each $n placeholder is positional.
func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) {
	if r == nil || r.db == nil {
		return 0, fmt.Errorf("nil ops repository")
	}
	if input == nil {
		return 0, fmt.Errorf("nil input")
	}

	q := `
	INSERT INTO ops_error_logs (
		request_id,
		client_request_id,
		user_id,
		api_key_id,
		account_id,
		group_id,
		client_ip,
		platform,
		model,
		request_path,
		stream,
		user_agent,
		error_phase,
		error_type,
		severity,
		status_code,
		is_business_limited,
		is_count_tokens,
		error_message,
		error_body,
		error_source,
		error_owner,
		upstream_status_code,
		upstream_error_message,
		upstream_error_detail,
		upstream_errors,
		duration_ms,
		time_to_first_token_ms,
		request_body,
		request_body_truncated,
		request_body_bytes,
		request_headers,
		is_retryable,
		retry_count,
		created_at
	) VALUES (
		$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35
	) RETURNING id`

	var id int64
	err := r.db.QueryRowContext(
		ctx,
		q,
		opsNullString(input.RequestID),
		opsNullString(input.ClientRequestID),
		opsNullInt64(input.UserID),
		opsNullInt64(input.APIKeyID),
		opsNullInt64(input.AccountID),
		opsNullInt64(input.GroupID),
		opsNullString(input.ClientIP),
		opsNullString(input.Platform),
		opsNullString(input.Model),
		opsNullString(input.RequestPath),
		input.Stream,
		opsNullString(input.UserAgent),
		input.ErrorPhase,
		input.ErrorType,
		opsNullString(input.Severity),
		opsNullInt(input.StatusCode),
		input.IsBusinessLimited,
		input.IsCountTokens,
		opsNullString(input.ErrorMessage),
		opsNullString(input.ErrorBody),
		opsNullString(input.ErrorSource),
		opsNullString(input.ErrorOwner),
		opsNullInt(input.UpstreamStatusCode),
		opsNullString(input.UpstreamErrorMessage),
		opsNullString(input.UpstreamErrorDetail),
		opsNullString(input.UpstreamErrorsJSON),
		opsNullInt(input.DurationMs),
		opsNullInt64(input.TimeToFirstTokenMs),
		opsNullString(input.RequestBodyJSON),
		input.RequestBodyTruncated,
		opsNullInt(input.RequestBodyBytes),
		opsNullString(input.RequestHeadersJSON),
		input.IsRetryable,
		input.RetryCount,
		input.CreatedAt,
	).Scan(&id)
	if err != nil {
		return 0, err
	}
	return id, nil
}
|
||||||
|
|
||||||
|
// ListErrorLogs returns one page of error-log summaries matching filter,
// newest first, together with the total count for the same WHERE clause.
// Paging defaults to page 1 / size 20 and the size is capped at 500.
func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if filter == nil {
		filter = &service.OpsErrorLogFilter{}
	}

	// Clamp paging inputs to sane defaults and an upper bound.
	page := filter.Page
	if page <= 0 {
		page = 1
	}
	pageSize := filter.PageSize
	if pageSize <= 0 {
		pageSize = 20
	}
	if pageSize > 500 {
		pageSize = 500
	}

	where, args := buildOpsErrorLogsWhere(filter)
	countSQL := "SELECT COUNT(*) FROM ops_error_logs " + where

	var total int
	if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
		return nil, err
	}

	offset := (page - 1) * pageSize
	argsWithLimit := append(args, pageSize, offset)
	// LIMIT/OFFSET placeholders continue the $n numbering produced by
	// buildOpsErrorLogsWhere, hence len(args)+1 and len(args)+2.
	selectSQL := `
	SELECT
		id,
		created_at,
		error_phase,
		error_type,
		severity,
		COALESCE(upstream_status_code, status_code, 0),
		COALESCE(platform, ''),
		COALESCE(model, ''),
		duration_ms,
		COALESCE(client_request_id, ''),
		COALESCE(request_id, ''),
		COALESCE(error_message, ''),
		user_id,
		api_key_id,
		account_id,
		group_id,
		CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
		COALESCE(request_path, ''),
		stream
	FROM ops_error_logs
	` + where + `
	ORDER BY created_at DESC
	LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)

	rows, err := r.db.QueryContext(ctx, selectSQL, argsWithLimit...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	out := make([]*service.OpsErrorLog, 0, pageSize)
	for rows.Next() {
		var item service.OpsErrorLog
		var latency sql.NullInt64
		var statusCode sql.NullInt64
		var clientIP sql.NullString
		var userID sql.NullInt64
		var apiKeyID sql.NullInt64
		var accountID sql.NullInt64
		var groupID sql.NullInt64
		// Scan order must match the 19-column SELECT list above exactly.
		if err := rows.Scan(
			&item.ID,
			&item.CreatedAt,
			&item.Phase,
			&item.Type,
			&item.Severity,
			&statusCode,
			&item.Platform,
			&item.Model,
			&latency,
			&item.ClientRequestID,
			&item.RequestID,
			&item.Message,
			&userID,
			&apiKeyID,
			&accountID,
			&groupID,
			&clientIP,
			&item.RequestPath,
			&item.Stream,
		); err != nil {
			return nil, err
		}
		// Unwrap nullable columns into the summary's pointer fields.
		if latency.Valid {
			v := int(latency.Int64)
			item.LatencyMs = &v
		}
		item.StatusCode = int(statusCode.Int64)
		if clientIP.Valid {
			s := clientIP.String
			item.ClientIP = &s
		}
		if userID.Valid {
			v := userID.Int64
			item.UserID = &v
		}
		if apiKeyID.Valid {
			v := apiKeyID.Int64
			item.APIKeyID = &v
		}
		if accountID.Valid {
			v := accountID.Int64
			item.AccountID = &v
		}
		if groupID.Valid {
			v := groupID.Int64
			item.GroupID = &v
		}
		out = append(out, &item)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return &service.OpsErrorLogList{
		Errors:   out,
		Total:    total,
		Page:     page,
		PageSize: pageSize,
	}, nil
}
|
||||||
|
|
||||||
|
// GetErrorLogByID loads the full detail record for a single error log row,
// including request/response bodies, per-phase latencies and upstream error
// context. Returns the underlying sql error (e.g. sql.ErrNoRows) when the id
// does not exist.
func (r *opsRepository) GetErrorLogByID(ctx context.Context, id int64) (*service.OpsErrorLogDetail, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if id <= 0 {
		return nil, fmt.Errorf("invalid id")
	}

	q := `
	SELECT
		id,
		created_at,
		error_phase,
		error_type,
		severity,
		COALESCE(upstream_status_code, status_code, 0),
		COALESCE(platform, ''),
		COALESCE(model, ''),
		duration_ms,
		COALESCE(client_request_id, ''),
		COALESCE(request_id, ''),
		COALESCE(error_message, ''),
		COALESCE(error_body, ''),
		upstream_status_code,
		COALESCE(upstream_error_message, ''),
		COALESCE(upstream_error_detail, ''),
		COALESCE(upstream_errors::text, ''),
		is_business_limited,
		user_id,
		api_key_id,
		account_id,
		group_id,
		CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
		COALESCE(request_path, ''),
		stream,
		COALESCE(user_agent, ''),
		auth_latency_ms,
		routing_latency_ms,
		upstream_latency_ms,
		response_latency_ms,
		time_to_first_token_ms,
		COALESCE(request_body::text, ''),
		request_body_truncated,
		request_body_bytes,
		COALESCE(request_headers::text, '')
	FROM ops_error_logs
	WHERE id = $1
	LIMIT 1`

	var out service.OpsErrorLogDetail
	var latency sql.NullInt64
	var statusCode sql.NullInt64
	var upstreamStatusCode sql.NullInt64
	var clientIP sql.NullString
	var userID sql.NullInt64
	var apiKeyID sql.NullInt64
	var accountID sql.NullInt64
	var groupID sql.NullInt64
	var authLatency sql.NullInt64
	var routingLatency sql.NullInt64
	var upstreamLatency sql.NullInt64
	var responseLatency sql.NullInt64
	var ttft sql.NullInt64
	var requestBodyBytes sql.NullInt64

	// Scan order must match the SELECT column list above exactly.
	err := r.db.QueryRowContext(ctx, q, id).Scan(
		&out.ID,
		&out.CreatedAt,
		&out.Phase,
		&out.Type,
		&out.Severity,
		&statusCode,
		&out.Platform,
		&out.Model,
		&latency,
		&out.ClientRequestID,
		&out.RequestID,
		&out.Message,
		&out.ErrorBody,
		&upstreamStatusCode,
		&out.UpstreamErrorMessage,
		&out.UpstreamErrorDetail,
		&out.UpstreamErrors,
		&out.IsBusinessLimited,
		&userID,
		&apiKeyID,
		&accountID,
		&groupID,
		&clientIP,
		&out.RequestPath,
		&out.Stream,
		&out.UserAgent,
		&authLatency,
		&routingLatency,
		&upstreamLatency,
		&responseLatency,
		&ttft,
		&out.RequestBody,
		&out.RequestBodyTruncated,
		&requestBodyBytes,
		&out.RequestHeaders,
	)
	if err != nil {
		return nil, err
	}

	// Unwrap nullable columns into pointer fields on the detail struct.
	out.StatusCode = int(statusCode.Int64)
	if latency.Valid {
		v := int(latency.Int64)
		out.LatencyMs = &v
	}
	if clientIP.Valid {
		s := clientIP.String
		out.ClientIP = &s
	}
	if upstreamStatusCode.Valid && upstreamStatusCode.Int64 > 0 {
		v := int(upstreamStatusCode.Int64)
		out.UpstreamStatusCode = &v
	}
	if userID.Valid {
		v := userID.Int64
		out.UserID = &v
	}
	if apiKeyID.Valid {
		v := apiKeyID.Int64
		out.APIKeyID = &v
	}
	if accountID.Valid {
		v := accountID.Int64
		out.AccountID = &v
	}
	if groupID.Valid {
		v := groupID.Int64
		out.GroupID = &v
	}
	if authLatency.Valid {
		v := authLatency.Int64
		out.AuthLatencyMs = &v
	}
	if routingLatency.Valid {
		v := routingLatency.Int64
		out.RoutingLatencyMs = &v
	}
	if upstreamLatency.Valid {
		v := upstreamLatency.Int64
		out.UpstreamLatencyMs = &v
	}
	if responseLatency.Valid {
		v := responseLatency.Int64
		out.ResponseLatencyMs = &v
	}
	if ttft.Valid {
		v := ttft.Int64
		out.TimeToFirstTokenMs = &v
	}
	if requestBodyBytes.Valid {
		v := int(requestBodyBytes.Int64)
		out.RequestBodyBytes = &v
	}

	// Normalize request_body to empty string when stored as JSON null.
	out.RequestBody = strings.TrimSpace(out.RequestBody)
	if out.RequestBody == "null" {
		out.RequestBody = ""
	}
	// Normalize request_headers to empty string when stored as JSON null.
	out.RequestHeaders = strings.TrimSpace(out.RequestHeaders)
	if out.RequestHeaders == "null" {
		out.RequestHeaders = ""
	}
	// Normalize upstream_errors to empty string when stored as JSON null.
	out.UpstreamErrors = strings.TrimSpace(out.UpstreamErrors)
	if out.UpstreamErrors == "null" {
		out.UpstreamErrors = ""
	}

	return &out, nil
}
|
||||||
|
|
||||||
|
// InsertRetryAttempt records the start of a manual retry of a logged error
// and returns the new attempt id. SourceErrorID and a non-blank Mode are
// required; the attempt row is later finalized via UpdateRetryAttempt.
func (r *opsRepository) InsertRetryAttempt(ctx context.Context, input *service.OpsInsertRetryAttemptInput) (int64, error) {
	if r == nil || r.db == nil {
		return 0, fmt.Errorf("nil ops repository")
	}
	if input == nil {
		return 0, fmt.Errorf("nil input")
	}
	if input.SourceErrorID <= 0 {
		return 0, fmt.Errorf("invalid source_error_id")
	}
	if strings.TrimSpace(input.Mode) == "" {
		return 0, fmt.Errorf("invalid mode")
	}

	q := `
	INSERT INTO ops_retry_attempts (
		requested_by_user_id,
		source_error_id,
		mode,
		pinned_account_id,
		status,
		started_at
	) VALUES (
		$1,$2,$3,$4,$5,$6
	) RETURNING id`

	var id int64
	err := r.db.QueryRowContext(
		ctx,
		q,
		// A zero RequestedByUserID is stored as NULL via the pointer helper.
		opsNullInt64(&input.RequestedByUserID),
		input.SourceErrorID,
		strings.TrimSpace(input.Mode),
		opsNullInt64(input.PinnedAccountID),
		strings.TrimSpace(input.Status),
		input.StartedAt,
	).Scan(&id)
	if err != nil {
		return 0, err
	}
	return id, nil
}
|
||||||
|
|
||||||
|
// UpdateRetryAttempt finalizes a retry attempt row: status, finish time,
// duration and the resulting request/error references. A zero FinishedAt is
// stored as NULL (see nullTime). Updating a non-existent id is not reported
// as an error — the UPDATE simply affects zero rows.
func (r *opsRepository) UpdateRetryAttempt(ctx context.Context, input *service.OpsUpdateRetryAttemptInput) error {
	if r == nil || r.db == nil {
		return fmt.Errorf("nil ops repository")
	}
	if input == nil {
		return fmt.Errorf("nil input")
	}
	if input.ID <= 0 {
		return fmt.Errorf("invalid id")
	}

	q := `
	UPDATE ops_retry_attempts
	SET
		status = $2,
		finished_at = $3,
		duration_ms = $4,
		result_request_id = $5,
		result_error_id = $6,
		error_message = $7
	WHERE id = $1`

	_, err := r.db.ExecContext(
		ctx,
		q,
		input.ID,
		strings.TrimSpace(input.Status),
		nullTime(input.FinishedAt),
		input.DurationMs,
		opsNullString(input.ResultRequestID),
		opsNullInt64(input.ResultErrorID),
		opsNullString(input.ErrorMessage),
	)
	return err
}
|
||||||
|
|
||||||
|
// GetLatestRetryAttemptForError returns the most recent retry attempt (by
// created_at) for the given source error, or the underlying sql error
// (e.g. sql.ErrNoRows) when none exists.
func (r *opsRepository) GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*service.OpsRetryAttempt, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if sourceErrorID <= 0 {
		return nil, fmt.Errorf("invalid source_error_id")
	}

	q := `
	SELECT
		id,
		created_at,
		COALESCE(requested_by_user_id, 0),
		source_error_id,
		COALESCE(mode, ''),
		pinned_account_id,
		COALESCE(status, ''),
		started_at,
		finished_at,
		duration_ms,
		result_request_id,
		result_error_id,
		error_message
	FROM ops_retry_attempts
	WHERE source_error_id = $1
	ORDER BY created_at DESC
	LIMIT 1`

	var out service.OpsRetryAttempt
	var pinnedAccountID sql.NullInt64
	var requestedBy sql.NullInt64
	var startedAt sql.NullTime
	var finishedAt sql.NullTime
	var durationMs sql.NullInt64
	var resultRequestID sql.NullString
	var resultErrorID sql.NullInt64
	var errorMessage sql.NullString

	// Scan order must match the 13-column SELECT list above exactly.
	err := r.db.QueryRowContext(ctx, q, sourceErrorID).Scan(
		&out.ID,
		&out.CreatedAt,
		&requestedBy,
		&out.SourceErrorID,
		&out.Mode,
		&pinnedAccountID,
		&out.Status,
		&startedAt,
		&finishedAt,
		&durationMs,
		&resultRequestID,
		&resultErrorID,
		&errorMessage,
	)
	if err != nil {
		return nil, err
	}
	// Unwrap nullable columns into pointer fields.
	out.RequestedByUserID = requestedBy.Int64
	if pinnedAccountID.Valid {
		v := pinnedAccountID.Int64
		out.PinnedAccountID = &v
	}
	if startedAt.Valid {
		t := startedAt.Time
		out.StartedAt = &t
	}
	if finishedAt.Valid {
		t := finishedAt.Time
		out.FinishedAt = &t
	}
	if durationMs.Valid {
		v := durationMs.Int64
		out.DurationMs = &v
	}
	if resultRequestID.Valid {
		s := resultRequestID.String
		out.ResultRequestID = &s
	}
	if resultErrorID.Valid {
		v := resultErrorID.Int64
		out.ResultErrorID = &v
	}
	if errorMessage.Valid {
		s := errorMessage.String
		out.ErrorMessage = &s
	}

	return &out, nil
}
|
||||||
|
|
||||||
|
// nullTime converts t into a sql.NullTime, mapping the zero time to SQL NULL.
func nullTime(t time.Time) sql.NullTime {
	if !t.IsZero() {
		return sql.NullTime{Time: t, Valid: true}
	}
	return sql.NullTime{}
}
|
||||||
|
|
||||||
|
func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||||
|
clauses := make([]string, 0, 8)
|
||||||
|
args := make([]any, 0, 8)
|
||||||
|
clauses = append(clauses, "1=1")
|
||||||
|
|
||||||
|
phaseFilter := ""
|
||||||
|
if filter != nil {
|
||||||
|
phaseFilter = strings.TrimSpace(strings.ToLower(filter.Phase))
|
||||||
|
}
|
||||||
|
// ops_error_logs primarily stores client-visible error requests (status>=400),
|
||||||
|
// but we also persist "recovered" upstream errors (status<400) for upstream health visibility.
|
||||||
|
// By default, keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
|
||||||
|
if phaseFilter != "upstream" {
|
||||||
|
clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
||||||
|
args = append(args, filter.StartTime.UTC())
|
||||||
|
clauses = append(clauses, "created_at >= $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
if filter.EndTime != nil && !filter.EndTime.IsZero() {
|
||||||
|
args = append(args, filter.EndTime.UTC())
|
||||||
|
// Keep time-window semantics consistent with other ops queries: [start, end)
|
||||||
|
clauses = append(clauses, "created_at < $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
if p := strings.TrimSpace(filter.Platform); p != "" {
|
||||||
|
args = append(args, p)
|
||||||
|
clauses = append(clauses, "platform = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
if filter.GroupID != nil && *filter.GroupID > 0 {
|
||||||
|
args = append(args, *filter.GroupID)
|
||||||
|
clauses = append(clauses, "group_id = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
if filter.AccountID != nil && *filter.AccountID > 0 {
|
||||||
|
args = append(args, *filter.AccountID)
|
||||||
|
clauses = append(clauses, "account_id = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
if phase := phaseFilter; phase != "" {
|
||||||
|
args = append(args, phase)
|
||||||
|
clauses = append(clauses, "error_phase = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
if len(filter.StatusCodes) > 0 {
|
||||||
|
args = append(args, pq.Array(filter.StatusCodes))
|
||||||
|
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
|
||||||
|
}
|
||||||
|
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||||
|
like := "%" + q + "%"
|
||||||
|
args = append(args, like)
|
||||||
|
n := itoa(len(args))
|
||||||
|
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
|
||||||
|
}
|
||||||
|
|
||||||
|
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helpers for nullable args
|
||||||
|
func opsNullString(v any) any {
|
||||||
|
switch s := v.(type) {
|
||||||
|
case nil:
|
||||||
|
return sql.NullString{}
|
||||||
|
case *string:
|
||||||
|
if s == nil || strings.TrimSpace(*s) == "" {
|
||||||
|
return sql.NullString{}
|
||||||
|
}
|
||||||
|
return sql.NullString{String: strings.TrimSpace(*s), Valid: true}
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(s) == "" {
|
||||||
|
return sql.NullString{}
|
||||||
|
}
|
||||||
|
return sql.NullString{String: strings.TrimSpace(s), Valid: true}
|
||||||
|
default:
|
||||||
|
return sql.NullString{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func opsNullInt64(v *int64) any {
|
||||||
|
if v == nil || *v == 0 {
|
||||||
|
return sql.NullInt64{}
|
||||||
|
}
|
||||||
|
return sql.NullInt64{Int64: *v, Valid: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
func opsNullInt(v any) any {
|
||||||
|
switch n := v.(type) {
|
||||||
|
case nil:
|
||||||
|
return sql.NullInt64{}
|
||||||
|
case *int:
|
||||||
|
if n == nil || *n == 0 {
|
||||||
|
return sql.NullInt64{}
|
||||||
|
}
|
||||||
|
return sql.NullInt64{Int64: int64(*n), Valid: true}
|
||||||
|
case *int64:
|
||||||
|
if n == nil || *n == 0 {
|
||||||
|
return sql.NullInt64{}
|
||||||
|
}
|
||||||
|
return sql.NullInt64{Int64: *n, Valid: true}
|
||||||
|
case int:
|
||||||
|
if n == 0 {
|
||||||
|
return sql.NullInt64{}
|
||||||
|
}
|
||||||
|
return sql.NullInt64{Int64: int64(n), Valid: true}
|
||||||
|
default:
|
||||||
|
return sql.NullInt64{}
|
||||||
|
}
|
||||||
|
}
|
||||||
689
backend/internal/repository/ops_repo_alerts.go
Normal file
689
backend/internal/repository/ops_repo_alerts.go
Normal file
@@ -0,0 +1,689 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ListAlertRules returns every alert rule, newest first (ORDER BY id DESC).
// Nullable text columns are coalesced to "" and notify_email to true in SQL
// so the scan targets can be plain Go values. The nullable
// last_triggered_at column and the JSONB filters payload are decoded after
// the scan; a malformed filters value is silently ignored rather than
// failing the whole listing.
func (r *opsRepository) ListAlertRules(ctx context.Context) ([]*service.OpsAlertRule, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}

	q := `
	SELECT
		id,
		name,
		COALESCE(description, ''),
		enabled,
		COALESCE(severity, ''),
		metric_type,
		operator,
		threshold,
		window_minutes,
		sustained_minutes,
		cooldown_minutes,
		COALESCE(notify_email, true),
		filters,
		last_triggered_at,
		created_at,
		updated_at
	FROM ops_alert_rules
	ORDER BY id DESC`

	rows, err := r.db.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	out := []*service.OpsAlertRule{}
	for rows.Next() {
		var rule service.OpsAlertRule
		var filtersRaw []byte
		var lastTriggeredAt sql.NullTime
		if err := rows.Scan(
			&rule.ID,
			&rule.Name,
			&rule.Description,
			&rule.Enabled,
			&rule.Severity,
			&rule.MetricType,
			&rule.Operator,
			&rule.Threshold,
			&rule.WindowMinutes,
			&rule.SustainedMinutes,
			&rule.CooldownMinutes,
			&rule.NotifyEmail,
			&filtersRaw,
			&lastTriggeredAt,
			&rule.CreatedAt,
			&rule.UpdatedAt,
		); err != nil {
			return nil, err
		}
		// Promote the nullable timestamp to an optional pointer field.
		if lastTriggeredAt.Valid {
			v := lastTriggeredAt.Time
			rule.LastTriggeredAt = &v
		}
		// Decode the JSONB filters payload; a literal "null" (and any decode
		// error) leaves rule.Filters nil.
		if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
			var decoded map[string]any
			if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
				rule.Filters = decoded
			}
		}
		out = append(out, &rule)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
|
||||||
|
|
||||||
|
// CreateAlertRule inserts a new alert rule and returns the persisted row
// (including the generated id and created/updated timestamps). String
// inputs are trimmed before storage, and the filters map is serialized to
// JSONB (a nil map stores NULL).
func (r *opsRepository) CreateAlertRule(ctx context.Context, input *service.OpsAlertRule) (*service.OpsAlertRule, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if input == nil {
		return nil, fmt.Errorf("nil input")
	}

	filtersArg, err := opsNullJSONMap(input.Filters)
	if err != nil {
		return nil, err
	}

	q := `
	INSERT INTO ops_alert_rules (
		name,
		description,
		enabled,
		severity,
		metric_type,
		operator,
		threshold,
		window_minutes,
		sustained_minutes,
		cooldown_minutes,
		notify_email,
		filters,
		created_at,
		updated_at
	) VALUES (
		$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,NOW(),NOW()
	)
	RETURNING
		id,
		name,
		COALESCE(description, ''),
		enabled,
		COALESCE(severity, ''),
		metric_type,
		operator,
		threshold,
		window_minutes,
		sustained_minutes,
		cooldown_minutes,
		COALESCE(notify_email, true),
		filters,
		last_triggered_at,
		created_at,
		updated_at`

	var out service.OpsAlertRule
	var filtersRaw []byte
	var lastTriggeredAt sql.NullTime

	if err := r.db.QueryRowContext(
		ctx,
		q,
		strings.TrimSpace(input.Name),
		strings.TrimSpace(input.Description),
		input.Enabled,
		strings.TrimSpace(input.Severity),
		strings.TrimSpace(input.MetricType),
		strings.TrimSpace(input.Operator),
		input.Threshold,
		input.WindowMinutes,
		input.SustainedMinutes,
		input.CooldownMinutes,
		input.NotifyEmail,
		filtersArg,
	).Scan(
		&out.ID,
		&out.Name,
		&out.Description,
		&out.Enabled,
		&out.Severity,
		&out.MetricType,
		&out.Operator,
		&out.Threshold,
		&out.WindowMinutes,
		&out.SustainedMinutes,
		&out.CooldownMinutes,
		&out.NotifyEmail,
		&filtersRaw,
		&lastTriggeredAt,
		&out.CreatedAt,
		&out.UpdatedAt,
	); err != nil {
		return nil, err
	}
	// Promote the nullable timestamp to an optional pointer field.
	if lastTriggeredAt.Valid {
		v := lastTriggeredAt.Time
		out.LastTriggeredAt = &v
	}
	// Decode the JSONB filters payload; a literal "null" (and any decode
	// error) leaves out.Filters nil.
	if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
		var decoded map[string]any
		if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
			out.Filters = decoded
		}
	}

	return &out, nil
}
|
||||||
|
|
||||||
|
// UpdateAlertRule overwrites every mutable column of an existing alert rule
// and returns the persisted row. String inputs are trimmed before storage,
// and the filters map is serialized to JSONB (a nil map stores NULL). When
// the id does not exist, Scan surfaces sql.ErrNoRows.
func (r *opsRepository) UpdateAlertRule(ctx context.Context, input *service.OpsAlertRule) (*service.OpsAlertRule, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if input == nil {
		return nil, fmt.Errorf("nil input")
	}
	if input.ID <= 0 {
		return nil, fmt.Errorf("invalid id")
	}

	filtersArg, err := opsNullJSONMap(input.Filters)
	if err != nil {
		return nil, err
	}

	q := `
	UPDATE ops_alert_rules
	SET
		name = $2,
		description = $3,
		enabled = $4,
		severity = $5,
		metric_type = $6,
		operator = $7,
		threshold = $8,
		window_minutes = $9,
		sustained_minutes = $10,
		cooldown_minutes = $11,
		notify_email = $12,
		filters = $13,
		updated_at = NOW()
	WHERE id = $1
	RETURNING
		id,
		name,
		COALESCE(description, ''),
		enabled,
		COALESCE(severity, ''),
		metric_type,
		operator,
		threshold,
		window_minutes,
		sustained_minutes,
		cooldown_minutes,
		COALESCE(notify_email, true),
		filters,
		last_triggered_at,
		created_at,
		updated_at`

	var out service.OpsAlertRule
	var filtersRaw []byte
	var lastTriggeredAt sql.NullTime

	if err := r.db.QueryRowContext(
		ctx,
		q,
		input.ID,
		strings.TrimSpace(input.Name),
		strings.TrimSpace(input.Description),
		input.Enabled,
		strings.TrimSpace(input.Severity),
		strings.TrimSpace(input.MetricType),
		strings.TrimSpace(input.Operator),
		input.Threshold,
		input.WindowMinutes,
		input.SustainedMinutes,
		input.CooldownMinutes,
		input.NotifyEmail,
		filtersArg,
	).Scan(
		&out.ID,
		&out.Name,
		&out.Description,
		&out.Enabled,
		&out.Severity,
		&out.MetricType,
		&out.Operator,
		&out.Threshold,
		&out.WindowMinutes,
		&out.SustainedMinutes,
		&out.CooldownMinutes,
		&out.NotifyEmail,
		&filtersRaw,
		&lastTriggeredAt,
		&out.CreatedAt,
		&out.UpdatedAt,
	); err != nil {
		return nil, err
	}

	// Promote the nullable timestamp to an optional pointer field.
	if lastTriggeredAt.Valid {
		v := lastTriggeredAt.Time
		out.LastTriggeredAt = &v
	}
	// Decode the JSONB filters payload; a literal "null" (and any decode
	// error) leaves out.Filters nil.
	if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
		var decoded map[string]any
		if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
			out.Filters = decoded
		}
	}

	return &out, nil
}
|
||||||
|
|
||||||
|
func (r *opsRepository) DeleteAlertRule(ctx context.Context, id int64) error {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if id <= 0 {
|
||||||
|
return fmt.Errorf("invalid id")
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := r.db.ExecContext(ctx, "DELETE FROM ops_alert_rules WHERE id = $1", id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected == 0 {
|
||||||
|
return sql.ErrNoRows
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) ListAlertEvents(ctx context.Context, filter *service.OpsAlertEventFilter) ([]*service.OpsAlertEvent, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if filter == nil {
|
||||||
|
filter = &service.OpsAlertEventFilter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := filter.Limit
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 100
|
||||||
|
}
|
||||||
|
if limit > 500 {
|
||||||
|
limit = 500
|
||||||
|
}
|
||||||
|
|
||||||
|
where, args := buildOpsAlertEventsWhere(filter)
|
||||||
|
args = append(args, limit)
|
||||||
|
limitArg := "$" + itoa(len(args))
|
||||||
|
|
||||||
|
q := `
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
COALESCE(rule_id, 0),
|
||||||
|
COALESCE(severity, ''),
|
||||||
|
COALESCE(status, ''),
|
||||||
|
COALESCE(title, ''),
|
||||||
|
COALESCE(description, ''),
|
||||||
|
metric_value,
|
||||||
|
threshold_value,
|
||||||
|
dimensions,
|
||||||
|
fired_at,
|
||||||
|
resolved_at,
|
||||||
|
email_sent,
|
||||||
|
created_at
|
||||||
|
FROM ops_alert_events
|
||||||
|
` + where + `
|
||||||
|
ORDER BY fired_at DESC
|
||||||
|
LIMIT ` + limitArg
|
||||||
|
|
||||||
|
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
out := []*service.OpsAlertEvent{}
|
||||||
|
for rows.Next() {
|
||||||
|
var ev service.OpsAlertEvent
|
||||||
|
var metricValue sql.NullFloat64
|
||||||
|
var thresholdValue sql.NullFloat64
|
||||||
|
var dimensionsRaw []byte
|
||||||
|
var resolvedAt sql.NullTime
|
||||||
|
if err := rows.Scan(
|
||||||
|
&ev.ID,
|
||||||
|
&ev.RuleID,
|
||||||
|
&ev.Severity,
|
||||||
|
&ev.Status,
|
||||||
|
&ev.Title,
|
||||||
|
&ev.Description,
|
||||||
|
&metricValue,
|
||||||
|
&thresholdValue,
|
||||||
|
&dimensionsRaw,
|
||||||
|
&ev.FiredAt,
|
||||||
|
&resolvedAt,
|
||||||
|
&ev.EmailSent,
|
||||||
|
&ev.CreatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if metricValue.Valid {
|
||||||
|
v := metricValue.Float64
|
||||||
|
ev.MetricValue = &v
|
||||||
|
}
|
||||||
|
if thresholdValue.Valid {
|
||||||
|
v := thresholdValue.Float64
|
||||||
|
ev.ThresholdValue = &v
|
||||||
|
}
|
||||||
|
if resolvedAt.Valid {
|
||||||
|
v := resolvedAt.Time
|
||||||
|
ev.ResolvedAt = &v
|
||||||
|
}
|
||||||
|
if len(dimensionsRaw) > 0 && string(dimensionsRaw) != "null" {
|
||||||
|
var decoded map[string]any
|
||||||
|
if err := json.Unmarshal(dimensionsRaw, &decoded); err == nil {
|
||||||
|
ev.Dimensions = decoded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, &ev)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if ruleID <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid rule id")
|
||||||
|
}
|
||||||
|
|
||||||
|
q := `
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
COALESCE(rule_id, 0),
|
||||||
|
COALESCE(severity, ''),
|
||||||
|
COALESCE(status, ''),
|
||||||
|
COALESCE(title, ''),
|
||||||
|
COALESCE(description, ''),
|
||||||
|
metric_value,
|
||||||
|
threshold_value,
|
||||||
|
dimensions,
|
||||||
|
fired_at,
|
||||||
|
resolved_at,
|
||||||
|
email_sent,
|
||||||
|
created_at
|
||||||
|
FROM ops_alert_events
|
||||||
|
WHERE rule_id = $1 AND status = $2
|
||||||
|
ORDER BY fired_at DESC
|
||||||
|
LIMIT 1`
|
||||||
|
|
||||||
|
row := r.db.QueryRowContext(ctx, q, ruleID, service.OpsAlertStatusFiring)
|
||||||
|
ev, err := scanOpsAlertEvent(row)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ev, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if ruleID <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid rule id")
|
||||||
|
}
|
||||||
|
|
||||||
|
q := `
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
COALESCE(rule_id, 0),
|
||||||
|
COALESCE(severity, ''),
|
||||||
|
COALESCE(status, ''),
|
||||||
|
COALESCE(title, ''),
|
||||||
|
COALESCE(description, ''),
|
||||||
|
metric_value,
|
||||||
|
threshold_value,
|
||||||
|
dimensions,
|
||||||
|
fired_at,
|
||||||
|
resolved_at,
|
||||||
|
email_sent,
|
||||||
|
created_at
|
||||||
|
FROM ops_alert_events
|
||||||
|
WHERE rule_id = $1
|
||||||
|
ORDER BY fired_at DESC
|
||||||
|
LIMIT 1`
|
||||||
|
|
||||||
|
row := r.db.QueryRowContext(ctx, q, ruleID)
|
||||||
|
ev, err := scanOpsAlertEvent(row)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ev, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAlertEvent inserts a new alert event and returns the persisted row
// (including the generated id and created_at). Optional fields are mapped
// to SQL NULL via the opsNull* helpers; note that RuleID == 0 is stored as
// NULL (opsNullInt64 treats zero as "unset").
func (r *opsRepository) CreateAlertEvent(ctx context.Context, event *service.OpsAlertEvent) (*service.OpsAlertEvent, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if event == nil {
		return nil, fmt.Errorf("nil event")
	}

	dimensionsArg, err := opsNullJSONMap(event.Dimensions)
	if err != nil {
		return nil, err
	}

	q := `
	INSERT INTO ops_alert_events (
		rule_id,
		severity,
		status,
		title,
		description,
		metric_value,
		threshold_value,
		dimensions,
		fired_at,
		resolved_at,
		email_sent,
		created_at
	) VALUES (
		$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,NOW()
	)
	RETURNING
		id,
		COALESCE(rule_id, 0),
		COALESCE(severity, ''),
		COALESCE(status, ''),
		COALESCE(title, ''),
		COALESCE(description, ''),
		metric_value,
		threshold_value,
		dimensions,
		fired_at,
		resolved_at,
		email_sent,
		created_at`

	row := r.db.QueryRowContext(
		ctx,
		q,
		opsNullInt64(&event.RuleID),
		opsNullString(event.Severity),
		opsNullString(event.Status),
		opsNullString(event.Title),
		opsNullString(event.Description),
		opsNullFloat64(event.MetricValue),
		opsNullFloat64(event.ThresholdValue),
		dimensionsArg,
		event.FiredAt,
		opsNullTime(event.ResolvedAt),
		event.EmailSent,
	)
	// Decode the RETURNING row through the shared scanner.
	return scanOpsAlertEvent(row)
}
|
||||||
|
|
||||||
|
func (r *opsRepository) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if eventID <= 0 {
|
||||||
|
return fmt.Errorf("invalid event id")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(status) == "" {
|
||||||
|
return fmt.Errorf("invalid status")
|
||||||
|
}
|
||||||
|
|
||||||
|
q := `
|
||||||
|
UPDATE ops_alert_events
|
||||||
|
SET status = $2,
|
||||||
|
resolved_at = $3
|
||||||
|
WHERE id = $1`
|
||||||
|
|
||||||
|
_, err := r.db.ExecContext(ctx, q, eventID, strings.TrimSpace(status), opsNullTime(resolvedAt))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if eventID <= 0 {
|
||||||
|
return fmt.Errorf("invalid event id")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := r.db.ExecContext(ctx, "UPDATE ops_alert_events SET email_sent = $2 WHERE id = $1", eventID, emailSent)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// opsAlertEventRow abstracts over *sql.Row and *sql.Rows so that
// scanOpsAlertEvent can decode an alert event from either a single-row
// query or an iterated result set.
type opsAlertEventRow interface {
	Scan(dest ...any) error
}
|
||||||
|
|
||||||
|
// scanOpsAlertEvent decodes one ops_alert_events row into a
// service.OpsAlertEvent. It expects the shared column order used by every
// alert-event query in this file: id, rule_id, severity, status, title,
// description, metric_value, threshold_value, dimensions, fired_at,
// resolved_at, email_sent, created_at. Scan errors (including
// sql.ErrNoRows from a *sql.Row) are returned unchanged.
func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) {
	var ev service.OpsAlertEvent
	var metricValue sql.NullFloat64
	var thresholdValue sql.NullFloat64
	var dimensionsRaw []byte
	var resolvedAt sql.NullTime

	if err := row.Scan(
		&ev.ID,
		&ev.RuleID,
		&ev.Severity,
		&ev.Status,
		&ev.Title,
		&ev.Description,
		&metricValue,
		&thresholdValue,
		&dimensionsRaw,
		&ev.FiredAt,
		&resolvedAt,
		&ev.EmailSent,
		&ev.CreatedAt,
	); err != nil {
		return nil, err
	}
	// Promote nullable columns to optional pointer fields.
	if metricValue.Valid {
		v := metricValue.Float64
		ev.MetricValue = &v
	}
	if thresholdValue.Valid {
		v := thresholdValue.Float64
		ev.ThresholdValue = &v
	}
	if resolvedAt.Valid {
		v := resolvedAt.Time
		ev.ResolvedAt = &v
	}
	// Decode the JSONB dimensions payload; a literal "null" (and any decode
	// error) leaves ev.Dimensions nil.
	if len(dimensionsRaw) > 0 && string(dimensionsRaw) != "null" {
		var decoded map[string]any
		if err := json.Unmarshal(dimensionsRaw, &decoded); err == nil {
			ev.Dimensions = decoded
		}
	}
	return &ev, nil
}
|
||||||
|
|
||||||
|
func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []any) {
|
||||||
|
clauses := []string{"1=1"}
|
||||||
|
args := []any{}
|
||||||
|
|
||||||
|
if filter == nil {
|
||||||
|
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||||
|
}
|
||||||
|
|
||||||
|
if status := strings.TrimSpace(filter.Status); status != "" {
|
||||||
|
args = append(args, status)
|
||||||
|
clauses = append(clauses, "status = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
if severity := strings.TrimSpace(filter.Severity); severity != "" {
|
||||||
|
args = append(args, severity)
|
||||||
|
clauses = append(clauses, "severity = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
||||||
|
args = append(args, *filter.StartTime)
|
||||||
|
clauses = append(clauses, "fired_at >= $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
if filter.EndTime != nil && !filter.EndTime.IsZero() {
|
||||||
|
args = append(args, *filter.EndTime)
|
||||||
|
clauses = append(clauses, "fired_at < $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
|
||||||
|
if platform := strings.TrimSpace(filter.Platform); platform != "" {
|
||||||
|
args = append(args, platform)
|
||||||
|
clauses = append(clauses, "(dimensions->>'platform') = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
if filter.GroupID != nil && *filter.GroupID > 0 {
|
||||||
|
args = append(args, fmt.Sprintf("%d", *filter.GroupID))
|
||||||
|
clauses = append(clauses, "(dimensions->>'group_id') = $"+itoa(len(args)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||||
|
}
|
||||||
|
|
||||||
|
func opsNullJSONMap(v map[string]any) (any, error) {
|
||||||
|
if v == nil {
|
||||||
|
return sql.NullString{}, nil
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(b) == 0 {
|
||||||
|
return sql.NullString{}, nil
|
||||||
|
}
|
||||||
|
return sql.NullString{String: string(b), Valid: true}, nil
|
||||||
|
}
|
||||||
1015
backend/internal/repository/ops_repo_dashboard.go
Normal file
1015
backend/internal/repository/ops_repo_dashboard.go
Normal file
File diff suppressed because it is too large
Load Diff
79
backend/internal/repository/ops_repo_histograms.go
Normal file
79
backend/internal/repository/ops_repo_histograms.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetLatencyHistogram aggregates usage_logs request durations for the given
// time window into the fixed latency buckets defined by
// latencyHistogramBuckets. It requires a non-nil filter with both StartTime
// and EndTime set; rows with NULL duration_ms are excluded. Buckets with no
// rows are still returned with a zero count so the response always contains
// the full, ordered bucket list.
func (r *opsRepository) GetLatencyHistogram(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsLatencyHistogramResponse, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if filter == nil {
		return nil, fmt.Errorf("nil filter")
	}
	if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
		return nil, fmt.Errorf("start_time/end_time required")
	}

	start := filter.StartTime.UTC()
	end := filter.EndTime.UTC()

	join, where, args, _ := buildUsageWhere(filter, start, end, 1)
	// CASE expressions mapping ul.duration_ms to its bucket label and to the
	// bucket's ordinal position (used only for SQL-side ordering).
	rangeExpr := latencyHistogramRangeCaseExpr("ul.duration_ms")
	orderExpr := latencyHistogramRangeOrderCaseExpr("ul.duration_ms")

	q := `
	SELECT
		` + rangeExpr + ` AS range,
		COALESCE(COUNT(*), 0) AS count,
		` + orderExpr + ` AS ord
	FROM usage_logs ul
	` + join + `
	` + where + `
	AND ul.duration_ms IS NOT NULL
	GROUP BY 1, 3
	ORDER BY 3 ASC`

	rows, err := r.db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	counts := make(map[string]int64, len(latencyHistogramOrderedRanges))
	var total int64
	for rows.Next() {
		var label string
		var count int64
		var _ord int // scanned but unused; ordering already applied in SQL
		if err := rows.Scan(&label, &count, &_ord); err != nil {
			return nil, err
		}
		counts[label] = count
		total += count
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// Emit every bucket in canonical order, defaulting missing ones to 0.
	buckets := make([]*service.OpsLatencyHistogramBucket, 0, len(latencyHistogramOrderedRanges))
	for _, label := range latencyHistogramOrderedRanges {
		buckets = append(buckets, &service.OpsLatencyHistogramBucket{
			Range: label,
			Count: counts[label],
		})
	}

	return &service.OpsLatencyHistogramResponse{
		StartTime:     start,
		EndTime:       end,
		Platform:      strings.TrimSpace(filter.Platform),
		GroupID:       filter.GroupID,
		TotalRequests: total,
		Buckets:       buckets,
	}, nil
}
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// latencyHistogramBucket describes one histogram bucket: requests whose
// duration is below upperMs (milliseconds) fall into it. upperMs == 0 marks
// the open-ended default ("everything slower") bucket.
type latencyHistogramBucket struct {
	upperMs int
	label   string
}
|
||||||
|
|
||||||
|
// latencyHistogramBuckets defines the latency buckets in ascending order of
// their upper bound. The final entry (upperMs == 0) is the catch-all for
// durations at or above the last finite bound and must stay last.
var latencyHistogramBuckets = []latencyHistogramBucket{
	{upperMs: 100, label: "0-100ms"},
	{upperMs: 200, label: "100-200ms"},
	{upperMs: 500, label: "200-500ms"},
	{upperMs: 1000, label: "500-1000ms"},
	{upperMs: 2000, label: "1000-2000ms"},
	{upperMs: 0, label: "2000ms+"}, // default bucket
}
|
||||||
|
|
||||||
|
var latencyHistogramOrderedRanges = func() []string {
|
||||||
|
out := make([]string, 0, len(latencyHistogramBuckets))
|
||||||
|
for _, b := range latencyHistogramBuckets {
|
||||||
|
out = append(out, b.label)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}()
|
||||||
|
|
||||||
|
func latencyHistogramRangeCaseExpr(column string) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
_, _ = sb.WriteString("CASE\n")
|
||||||
|
|
||||||
|
for _, b := range latencyHistogramBuckets {
|
||||||
|
if b.upperMs <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default bucket.
|
||||||
|
last := latencyHistogramBuckets[len(latencyHistogramBuckets)-1]
|
||||||
|
_, _ = sb.WriteString(fmt.Sprintf("\tELSE '%s'\n", last.label))
|
||||||
|
_, _ = sb.WriteString("END")
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func latencyHistogramRangeOrderCaseExpr(column string) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
_, _ = sb.WriteString("CASE\n")
|
||||||
|
|
||||||
|
order := 1
|
||||||
|
for _, b := range latencyHistogramBuckets {
|
||||||
|
if b.upperMs <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN %d\n", column, b.upperMs, order))
|
||||||
|
order++
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = sb.WriteString(fmt.Sprintf("\tELSE %d\n", order))
|
||||||
|
_, _ = sb.WriteString("END")
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLatencyHistogramBuckets_AreConsistent(t *testing.T) {
|
||||||
|
require.Equal(t, len(latencyHistogramBuckets), len(latencyHistogramOrderedRanges))
|
||||||
|
for i, b := range latencyHistogramBuckets {
|
||||||
|
require.Equal(t, b.label, latencyHistogramOrderedRanges[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
422
backend/internal/repository/ops_repo_metrics.go
Normal file
422
backend/internal/repository/ops_repo_metrics.go
Normal file
@@ -0,0 +1,422 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InsertSystemMetrics persists one system-metrics sample row into
// ops_system_metrics. Optional input fields are written as SQL NULL via the
// opsNull* helpers. A non-positive WindowMinutes defaults to 1 and a zero
// CreatedAt defaults to the current UTC time.
func (r *opsRepository) InsertSystemMetrics(ctx context.Context, input *service.OpsInsertSystemMetricsInput) error {
	if r == nil || r.db == nil {
		return fmt.Errorf("nil ops repository")
	}
	if input == nil {
		return fmt.Errorf("nil input")
	}

	// Normalize inputs before writing the row.
	window := input.WindowMinutes
	if window <= 0 {
		window = 1
	}
	createdAt := input.CreatedAt
	if createdAt.IsZero() {
		createdAt = time.Now().UTC()
	}

	q := `
INSERT INTO ops_system_metrics (
	created_at,
	window_minutes,
	platform,
	group_id,

	success_count,
	error_count_total,
	business_limited_count,
	error_count_sla,

	upstream_error_count_excl_429_529,
	upstream_429_count,
	upstream_529_count,

	token_consumed,
	qps,
	tps,

	duration_p50_ms,
	duration_p90_ms,
	duration_p95_ms,
	duration_p99_ms,
	duration_avg_ms,
	duration_max_ms,

	ttft_p50_ms,
	ttft_p90_ms,
	ttft_p95_ms,
	ttft_p99_ms,
	ttft_avg_ms,
	ttft_max_ms,

	cpu_usage_percent,
	memory_used_mb,
	memory_total_mb,
	memory_usage_percent,

	db_ok,
	redis_ok,

	redis_conn_total,
	redis_conn_idle,

	db_conn_active,
	db_conn_idle,
	db_conn_waiting,

	goroutine_count,
	concurrency_queue_depth
) VALUES (
	$1,$2,$3,$4,
	$5,$6,$7,$8,
	$9,$10,$11,
	$12,$13,$14,
	$15,$16,$17,$18,$19,$20,
	$21,$22,$23,$24,$25,$26,
	$27,$28,$29,$30,
	$31,$32,
	$33,$34,
	$35,$36,$37,
	$38,$39
)`

	// IMPORTANT: argument order below must stay aligned with the $1..$39
	// placeholders above; the grouping mirrors the column list for review.
	_, err := r.db.ExecContext(
		ctx,
		q,
		createdAt,
		window,
		opsNullString(input.Platform),
		opsNullInt64(input.GroupID),

		input.SuccessCount,
		input.ErrorCountTotal,
		input.BusinessLimitedCount,
		input.ErrorCountSLA,

		input.UpstreamErrorCountExcl429529,
		input.Upstream429Count,
		input.Upstream529Count,

		input.TokenConsumed,
		opsNullFloat64(input.QPS),
		opsNullFloat64(input.TPS),

		opsNullInt(input.DurationP50Ms),
		opsNullInt(input.DurationP90Ms),
		opsNullInt(input.DurationP95Ms),
		opsNullInt(input.DurationP99Ms),
		opsNullFloat64(input.DurationAvgMs),
		opsNullInt(input.DurationMaxMs),

		opsNullInt(input.TTFTP50Ms),
		opsNullInt(input.TTFTP90Ms),
		opsNullInt(input.TTFTP95Ms),
		opsNullInt(input.TTFTP99Ms),
		opsNullFloat64(input.TTFTAvgMs),
		opsNullInt(input.TTFTMaxMs),

		opsNullFloat64(input.CPUUsagePercent),
		opsNullInt(input.MemoryUsedMB),
		opsNullInt(input.MemoryTotalMB),
		opsNullFloat64(input.MemoryUsagePercent),

		opsNullBool(input.DBOK),
		opsNullBool(input.RedisOK),

		opsNullInt(input.RedisConnTotal),
		opsNullInt(input.RedisConnIdle),

		opsNullInt(input.DBConnActive),
		opsNullInt(input.DBConnIdle),
		opsNullInt(input.DBConnWaiting),

		opsNullInt(input.GoroutineCount),
		opsNullInt(input.ConcurrencyQueueDepth),
	)
	return err
}
|
||||||
|
|
||||||
|
func (r *opsRepository) GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*service.OpsSystemMetricsSnapshot, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if windowMinutes <= 0 {
|
||||||
|
windowMinutes = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
q := `
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
created_at,
|
||||||
|
window_minutes,
|
||||||
|
|
||||||
|
cpu_usage_percent,
|
||||||
|
memory_used_mb,
|
||||||
|
memory_total_mb,
|
||||||
|
memory_usage_percent,
|
||||||
|
|
||||||
|
db_ok,
|
||||||
|
redis_ok,
|
||||||
|
|
||||||
|
redis_conn_total,
|
||||||
|
redis_conn_idle,
|
||||||
|
|
||||||
|
db_conn_active,
|
||||||
|
db_conn_idle,
|
||||||
|
db_conn_waiting,
|
||||||
|
|
||||||
|
goroutine_count,
|
||||||
|
concurrency_queue_depth
|
||||||
|
FROM ops_system_metrics
|
||||||
|
WHERE window_minutes = $1
|
||||||
|
AND platform IS NULL
|
||||||
|
AND group_id IS NULL
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT 1`
|
||||||
|
|
||||||
|
var out service.OpsSystemMetricsSnapshot
|
||||||
|
var cpu sql.NullFloat64
|
||||||
|
var memUsed sql.NullInt64
|
||||||
|
var memTotal sql.NullInt64
|
||||||
|
var memPct sql.NullFloat64
|
||||||
|
var dbOK sql.NullBool
|
||||||
|
var redisOK sql.NullBool
|
||||||
|
var redisTotal sql.NullInt64
|
||||||
|
var redisIdle sql.NullInt64
|
||||||
|
var dbActive sql.NullInt64
|
||||||
|
var dbIdle sql.NullInt64
|
||||||
|
var dbWaiting sql.NullInt64
|
||||||
|
var goroutines sql.NullInt64
|
||||||
|
var queueDepth sql.NullInt64
|
||||||
|
|
||||||
|
if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan(
|
||||||
|
&out.ID,
|
||||||
|
&out.CreatedAt,
|
||||||
|
&out.WindowMinutes,
|
||||||
|
&cpu,
|
||||||
|
&memUsed,
|
||||||
|
&memTotal,
|
||||||
|
&memPct,
|
||||||
|
&dbOK,
|
||||||
|
&redisOK,
|
||||||
|
&redisTotal,
|
||||||
|
&redisIdle,
|
||||||
|
&dbActive,
|
||||||
|
&dbIdle,
|
||||||
|
&dbWaiting,
|
||||||
|
&goroutines,
|
||||||
|
&queueDepth,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if cpu.Valid {
|
||||||
|
v := cpu.Float64
|
||||||
|
out.CPUUsagePercent = &v
|
||||||
|
}
|
||||||
|
if memUsed.Valid {
|
||||||
|
v := memUsed.Int64
|
||||||
|
out.MemoryUsedMB = &v
|
||||||
|
}
|
||||||
|
if memTotal.Valid {
|
||||||
|
v := memTotal.Int64
|
||||||
|
out.MemoryTotalMB = &v
|
||||||
|
}
|
||||||
|
if memPct.Valid {
|
||||||
|
v := memPct.Float64
|
||||||
|
out.MemoryUsagePercent = &v
|
||||||
|
}
|
||||||
|
if dbOK.Valid {
|
||||||
|
v := dbOK.Bool
|
||||||
|
out.DBOK = &v
|
||||||
|
}
|
||||||
|
if redisOK.Valid {
|
||||||
|
v := redisOK.Bool
|
||||||
|
out.RedisOK = &v
|
||||||
|
}
|
||||||
|
if redisTotal.Valid {
|
||||||
|
v := int(redisTotal.Int64)
|
||||||
|
out.RedisConnTotal = &v
|
||||||
|
}
|
||||||
|
if redisIdle.Valid {
|
||||||
|
v := int(redisIdle.Int64)
|
||||||
|
out.RedisConnIdle = &v
|
||||||
|
}
|
||||||
|
if dbActive.Valid {
|
||||||
|
v := int(dbActive.Int64)
|
||||||
|
out.DBConnActive = &v
|
||||||
|
}
|
||||||
|
if dbIdle.Valid {
|
||||||
|
v := int(dbIdle.Int64)
|
||||||
|
out.DBConnIdle = &v
|
||||||
|
}
|
||||||
|
if dbWaiting.Valid {
|
||||||
|
v := int(dbWaiting.Int64)
|
||||||
|
out.DBConnWaiting = &v
|
||||||
|
}
|
||||||
|
if goroutines.Valid {
|
||||||
|
v := int(goroutines.Int64)
|
||||||
|
out.GoroutineCount = &v
|
||||||
|
}
|
||||||
|
if queueDepth.Valid {
|
||||||
|
v := int(queueDepth.Int64)
|
||||||
|
out.ConcurrencyQueueDepth = &v
|
||||||
|
}
|
||||||
|
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpsertJobHeartbeat inserts or updates the heartbeat row for a background
// job, keyed by job_name. On conflict each column keeps its previous value
// unless the new input provides one (COALESCE); a run that reports
// last_success_at additionally clears the stored error state
// (last_error_at / last_error). updated_at is always bumped to NOW().
func (r *opsRepository) UpsertJobHeartbeat(ctx context.Context, input *service.OpsUpsertJobHeartbeatInput) error {
	if r == nil || r.db == nil {
		return fmt.Errorf("nil ops repository")
	}
	if input == nil {
		return fmt.Errorf("nil input")
	}
	if input.JobName == "" {
		return fmt.Errorf("job_name required")
	}

	q := `
INSERT INTO ops_job_heartbeats (
	job_name,
	last_run_at,
	last_success_at,
	last_error_at,
	last_error,
	last_duration_ms,
	updated_at
) VALUES (
	$1,$2,$3,$4,$5,$6,NOW()
)
ON CONFLICT (job_name) DO UPDATE SET
	last_run_at = COALESCE(EXCLUDED.last_run_at, ops_job_heartbeats.last_run_at),
	last_success_at = COALESCE(EXCLUDED.last_success_at, ops_job_heartbeats.last_success_at),
	last_error_at = CASE
		WHEN EXCLUDED.last_success_at IS NOT NULL THEN NULL
		ELSE COALESCE(EXCLUDED.last_error_at, ops_job_heartbeats.last_error_at)
	END,
	last_error = CASE
		WHEN EXCLUDED.last_success_at IS NOT NULL THEN NULL
		ELSE COALESCE(EXCLUDED.last_error, ops_job_heartbeats.last_error)
	END,
	last_duration_ms = COALESCE(EXCLUDED.last_duration_ms, ops_job_heartbeats.last_duration_ms),
	updated_at = NOW()`

	// Argument order matches $1..$6. nil (or zero-time) inputs become SQL
	// NULL via the opsNull* helpers so the COALESCE logic above can retain
	// the previously stored values.
	_, err := r.db.ExecContext(
		ctx,
		q,
		input.JobName,
		opsNullTime(input.LastRunAt),
		opsNullTime(input.LastSuccessAt),
		opsNullTime(input.LastErrorAt),
		opsNullString(input.LastError),
		opsNullInt(input.LastDurationMs),
	)
	return err
}
|
||||||
|
|
||||||
|
func (r *opsRepository) ListJobHeartbeats(ctx context.Context) ([]*service.OpsJobHeartbeat, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
|
||||||
|
q := `
|
||||||
|
SELECT
|
||||||
|
job_name,
|
||||||
|
last_run_at,
|
||||||
|
last_success_at,
|
||||||
|
last_error_at,
|
||||||
|
last_error,
|
||||||
|
last_duration_ms,
|
||||||
|
updated_at
|
||||||
|
FROM ops_job_heartbeats
|
||||||
|
ORDER BY job_name ASC`
|
||||||
|
|
||||||
|
rows, err := r.db.QueryContext(ctx, q)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
out := make([]*service.OpsJobHeartbeat, 0, 8)
|
||||||
|
for rows.Next() {
|
||||||
|
var item service.OpsJobHeartbeat
|
||||||
|
var lastRun sql.NullTime
|
||||||
|
var lastSuccess sql.NullTime
|
||||||
|
var lastErrorAt sql.NullTime
|
||||||
|
var lastError sql.NullString
|
||||||
|
var lastDuration sql.NullInt64
|
||||||
|
|
||||||
|
if err := rows.Scan(
|
||||||
|
&item.JobName,
|
||||||
|
&lastRun,
|
||||||
|
&lastSuccess,
|
||||||
|
&lastErrorAt,
|
||||||
|
&lastError,
|
||||||
|
&lastDuration,
|
||||||
|
&item.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastRun.Valid {
|
||||||
|
v := lastRun.Time
|
||||||
|
item.LastRunAt = &v
|
||||||
|
}
|
||||||
|
if lastSuccess.Valid {
|
||||||
|
v := lastSuccess.Time
|
||||||
|
item.LastSuccessAt = &v
|
||||||
|
}
|
||||||
|
if lastErrorAt.Valid {
|
||||||
|
v := lastErrorAt.Time
|
||||||
|
item.LastErrorAt = &v
|
||||||
|
}
|
||||||
|
if lastError.Valid {
|
||||||
|
v := lastError.String
|
||||||
|
item.LastError = &v
|
||||||
|
}
|
||||||
|
if lastDuration.Valid {
|
||||||
|
v := lastDuration.Int64
|
||||||
|
item.LastDurationMs = &v
|
||||||
|
}
|
||||||
|
|
||||||
|
out = append(out, &item)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func opsNullBool(v *bool) any {
|
||||||
|
if v == nil {
|
||||||
|
return sql.NullBool{}
|
||||||
|
}
|
||||||
|
return sql.NullBool{Bool: *v, Valid: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
func opsNullFloat64(v *float64) any {
|
||||||
|
if v == nil {
|
||||||
|
return sql.NullFloat64{}
|
||||||
|
}
|
||||||
|
return sql.NullFloat64{Float64: *v, Valid: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
func opsNullTime(v *time.Time) any {
|
||||||
|
if v == nil || v.IsZero() {
|
||||||
|
return sql.NullTime{}
|
||||||
|
}
|
||||||
|
return sql.NullTime{Time: *v, Valid: true}
|
||||||
|
}
|
||||||
361
backend/internal/repository/ops_repo_preagg.go
Normal file
361
backend/internal/repository/ops_repo_preagg.go
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpsertHourlyMetrics recomputes hourly pre-aggregated request metrics for
// the half-open window [startTime, endTime) and upserts them into
// ops_metrics_hourly. A zero or inverted range is a silent no-op. Times are
// normalized to UTC before being bound to the query.
func (r *opsRepository) UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error {
	if r == nil || r.db == nil {
		return fmt.Errorf("nil ops repository")
	}
	if startTime.IsZero() || endTime.IsZero() || !endTime.After(startTime) {
		return nil
	}

	start := startTime.UTC()
	end := endTime.UTC()

	// NOTE:
	// - We aggregate usage_logs + ops_error_logs into ops_metrics_hourly.
	// - We emit three dimension granularities via GROUPING SETS:
	//   1) overall:  (bucket_start)
	//   2) platform: (bucket_start, platform)
	//   3) group:    (bucket_start, platform, group_id)
	//
	// IMPORTANT: Postgres UNIQUE treats NULLs as distinct, so the table uses a COALESCE-based
	// unique index; our ON CONFLICT target must match that expression set.
	q := `
WITH usage_base AS (
	SELECT
		date_trunc('hour', ul.created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
		g.platform AS platform,
		ul.group_id AS group_id,
		ul.duration_ms AS duration_ms,
		ul.first_token_ms AS first_token_ms,
		(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens) AS tokens
	FROM usage_logs ul
	JOIN groups g ON g.id = ul.group_id
	WHERE ul.created_at >= $1 AND ul.created_at < $2
),
usage_agg AS (
	SELECT
		bucket_start,
		CASE WHEN GROUPING(platform) = 1 THEN NULL ELSE platform END AS platform,
		CASE WHEN GROUPING(group_id) = 1 THEN NULL ELSE group_id END AS group_id,
		COUNT(*) AS success_count,
		COALESCE(SUM(tokens), 0) AS token_consumed,

		percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50_ms,
		percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90_ms,
		percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95_ms,
		percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99_ms,
		AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg_ms,
		MAX(duration_ms) AS duration_max_ms,

		percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50_ms,
		percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90_ms,
		percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95_ms,
		percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99_ms,
		AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg_ms,
		MAX(first_token_ms) AS ttft_max_ms
	FROM usage_base
	GROUP BY GROUPING SETS (
		(bucket_start),
		(bucket_start, platform),
		(bucket_start, platform, group_id)
	)
),
error_base AS (
	SELECT
		date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
		platform AS platform,
		group_id AS group_id,
		is_business_limited AS is_business_limited,
		error_owner AS error_owner,
		status_code AS client_status_code,
		COALESCE(upstream_status_code, status_code, 0) AS effective_status_code
	FROM ops_error_logs
	-- Exclude count_tokens requests from error metrics as they are informational probes
	WHERE created_at >= $1 AND created_at < $2
	  AND is_count_tokens = FALSE
),
error_agg AS (
	SELECT
		bucket_start,
		CASE WHEN GROUPING(platform) = 1 THEN NULL ELSE platform END AS platform,
		CASE WHEN GROUPING(group_id) = 1 THEN NULL ELSE group_id END AS group_id,
		COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400) AS error_count_total,
		COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400 AND is_business_limited) AS business_limited_count,
		COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400 AND NOT is_business_limited) AS error_count_sla,
		COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) NOT IN (429, 529)) AS upstream_error_count_excl_429_529,
		COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) = 429) AS upstream_429_count,
		COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) = 529) AS upstream_529_count
	FROM error_base
	GROUP BY GROUPING SETS (
		(bucket_start),
		(bucket_start, platform),
		(bucket_start, platform, group_id)
	)
	-- Drop group-granularity rows whose group_id is NULL.
	-- NOTE(review): usage_agg has no matching HAVING; presumably
	-- usage_logs.group_id is NOT NULL via the groups join - confirm.
	HAVING GROUPING(group_id) = 1 OR group_id IS NOT NULL
),
combined AS (
	SELECT
		COALESCE(u.bucket_start, e.bucket_start) AS bucket_start,
		COALESCE(u.platform, e.platform) AS platform,
		COALESCE(u.group_id, e.group_id) AS group_id,

		COALESCE(u.success_count, 0) AS success_count,
		COALESCE(e.error_count_total, 0) AS error_count_total,
		COALESCE(e.business_limited_count, 0) AS business_limited_count,
		COALESCE(e.error_count_sla, 0) AS error_count_sla,
		COALESCE(e.upstream_error_count_excl_429_529, 0) AS upstream_error_count_excl_429_529,
		COALESCE(e.upstream_429_count, 0) AS upstream_429_count,
		COALESCE(e.upstream_529_count, 0) AS upstream_529_count,

		COALESCE(u.token_consumed, 0) AS token_consumed,

		u.duration_p50_ms,
		u.duration_p90_ms,
		u.duration_p95_ms,
		u.duration_p99_ms,
		u.duration_avg_ms,
		u.duration_max_ms,

		u.ttft_p50_ms,
		u.ttft_p90_ms,
		u.ttft_p95_ms,
		u.ttft_p99_ms,
		u.ttft_avg_ms,
		u.ttft_max_ms
	FROM usage_agg u
	FULL OUTER JOIN error_agg e
		ON u.bucket_start = e.bucket_start
		AND COALESCE(u.platform, '') = COALESCE(e.platform, '')
		AND COALESCE(u.group_id, 0) = COALESCE(e.group_id, 0)
)
INSERT INTO ops_metrics_hourly (
	bucket_start,
	platform,
	group_id,
	success_count,
	error_count_total,
	business_limited_count,
	error_count_sla,
	upstream_error_count_excl_429_529,
	upstream_429_count,
	upstream_529_count,
	token_consumed,
	duration_p50_ms,
	duration_p90_ms,
	duration_p95_ms,
	duration_p99_ms,
	duration_avg_ms,
	duration_max_ms,
	ttft_p50_ms,
	ttft_p90_ms,
	ttft_p95_ms,
	ttft_p99_ms,
	ttft_avg_ms,
	ttft_max_ms,
	computed_at
)
SELECT
	bucket_start,
	NULLIF(platform, '') AS platform,
	group_id,
	success_count,
	error_count_total,
	business_limited_count,
	error_count_sla,
	upstream_error_count_excl_429_529,
	upstream_429_count,
	upstream_529_count,
	token_consumed,
	duration_p50_ms::int,
	duration_p90_ms::int,
	duration_p95_ms::int,
	duration_p99_ms::int,
	duration_avg_ms,
	duration_max_ms::int,
	ttft_p50_ms::int,
	ttft_p90_ms::int,
	ttft_p95_ms::int,
	ttft_p99_ms::int,
	ttft_avg_ms,
	ttft_max_ms::int,
	NOW()
FROM combined
WHERE bucket_start IS NOT NULL
  AND (platform IS NULL OR platform <> '')
ON CONFLICT (bucket_start, COALESCE(platform, ''), COALESCE(group_id, 0)) DO UPDATE SET
	success_count = EXCLUDED.success_count,
	error_count_total = EXCLUDED.error_count_total,
	business_limited_count = EXCLUDED.business_limited_count,
	error_count_sla = EXCLUDED.error_count_sla,
	upstream_error_count_excl_429_529 = EXCLUDED.upstream_error_count_excl_429_529,
	upstream_429_count = EXCLUDED.upstream_429_count,
	upstream_529_count = EXCLUDED.upstream_529_count,
	token_consumed = EXCLUDED.token_consumed,

	duration_p50_ms = EXCLUDED.duration_p50_ms,
	duration_p90_ms = EXCLUDED.duration_p90_ms,
	duration_p95_ms = EXCLUDED.duration_p95_ms,
	duration_p99_ms = EXCLUDED.duration_p99_ms,
	duration_avg_ms = EXCLUDED.duration_avg_ms,
	duration_max_ms = EXCLUDED.duration_max_ms,

	ttft_p50_ms = EXCLUDED.ttft_p50_ms,
	ttft_p90_ms = EXCLUDED.ttft_p90_ms,
	ttft_p95_ms = EXCLUDED.ttft_p95_ms,
	ttft_p99_ms = EXCLUDED.ttft_p99_ms,
	ttft_avg_ms = EXCLUDED.ttft_avg_ms,
	ttft_max_ms = EXCLUDED.ttft_max_ms,

	computed_at = NOW()
`

	_, err := r.db.ExecContext(ctx, q, start, end)
	return err
}
|
||||||
|
|
||||||
|
// UpsertDailyMetrics rolls the hourly pre-aggregates in [startTime, endTime)
// up into ops_metrics_daily, one row per (date, platform, group_id). Counts
// and tokens are summed exactly; latency percentiles are approximated from
// the hourly rows (see the in-query comment). A zero or inverted range is a
// silent no-op.
func (r *opsRepository) UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error {
	if r == nil || r.db == nil {
		return fmt.Errorf("nil ops repository")
	}
	if startTime.IsZero() || endTime.IsZero() || !endTime.After(startTime) {
		return nil
	}

	start := startTime.UTC()
	end := endTime.UTC()

	q := `
INSERT INTO ops_metrics_daily (
	bucket_date,
	platform,
	group_id,
	success_count,
	error_count_total,
	business_limited_count,
	error_count_sla,
	upstream_error_count_excl_429_529,
	upstream_429_count,
	upstream_529_count,
	token_consumed,
	duration_p50_ms,
	duration_p90_ms,
	duration_p95_ms,
	duration_p99_ms,
	duration_avg_ms,
	duration_max_ms,
	ttft_p50_ms,
	ttft_p90_ms,
	ttft_p95_ms,
	ttft_p99_ms,
	ttft_avg_ms,
	ttft_max_ms,
	computed_at
)
SELECT
	(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
	platform,
	group_id,

	COALESCE(SUM(success_count), 0) AS success_count,
	COALESCE(SUM(error_count_total), 0) AS error_count_total,
	COALESCE(SUM(business_limited_count), 0) AS business_limited_count,
	COALESCE(SUM(error_count_sla), 0) AS error_count_sla,
	COALESCE(SUM(upstream_error_count_excl_429_529), 0) AS upstream_error_count_excl_429_529,
	COALESCE(SUM(upstream_429_count), 0) AS upstream_429_count,
	COALESCE(SUM(upstream_529_count), 0) AS upstream_529_count,
	COALESCE(SUM(token_consumed), 0) AS token_consumed,

	-- Approximation: weighted average for p50/p90, max for p95/p99 (conservative tail).
	ROUND(SUM(duration_p50_ms::double precision * success_count) FILTER (WHERE duration_p50_ms IS NOT NULL)
		/ NULLIF(SUM(success_count) FILTER (WHERE duration_p50_ms IS NOT NULL), 0))::int AS duration_p50_ms,
	ROUND(SUM(duration_p90_ms::double precision * success_count) FILTER (WHERE duration_p90_ms IS NOT NULL)
		/ NULLIF(SUM(success_count) FILTER (WHERE duration_p90_ms IS NOT NULL), 0))::int AS duration_p90_ms,
	MAX(duration_p95_ms) AS duration_p95_ms,
	MAX(duration_p99_ms) AS duration_p99_ms,
	SUM(duration_avg_ms * success_count) FILTER (WHERE duration_avg_ms IS NOT NULL)
		/ NULLIF(SUM(success_count) FILTER (WHERE duration_avg_ms IS NOT NULL), 0) AS duration_avg_ms,
	MAX(duration_max_ms) AS duration_max_ms,

	ROUND(SUM(ttft_p50_ms::double precision * success_count) FILTER (WHERE ttft_p50_ms IS NOT NULL)
		/ NULLIF(SUM(success_count) FILTER (WHERE ttft_p50_ms IS NOT NULL), 0))::int AS ttft_p50_ms,
	ROUND(SUM(ttft_p90_ms::double precision * success_count) FILTER (WHERE ttft_p90_ms IS NOT NULL)
		/ NULLIF(SUM(success_count) FILTER (WHERE ttft_p90_ms IS NOT NULL), 0))::int AS ttft_p90_ms,
	MAX(ttft_p95_ms) AS ttft_p95_ms,
	MAX(ttft_p99_ms) AS ttft_p99_ms,
	SUM(ttft_avg_ms * success_count) FILTER (WHERE ttft_avg_ms IS NOT NULL)
		/ NULLIF(SUM(success_count) FILTER (WHERE ttft_avg_ms IS NOT NULL), 0) AS ttft_avg_ms,
	MAX(ttft_max_ms) AS ttft_max_ms,

	NOW()
FROM ops_metrics_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY 1, 2, 3
ON CONFLICT (bucket_date, COALESCE(platform, ''), COALESCE(group_id, 0)) DO UPDATE SET
	success_count = EXCLUDED.success_count,
	error_count_total = EXCLUDED.error_count_total,
	business_limited_count = EXCLUDED.business_limited_count,
	error_count_sla = EXCLUDED.error_count_sla,
	upstream_error_count_excl_429_529 = EXCLUDED.upstream_error_count_excl_429_529,
	upstream_429_count = EXCLUDED.upstream_429_count,
	upstream_529_count = EXCLUDED.upstream_529_count,
	token_consumed = EXCLUDED.token_consumed,

	duration_p50_ms = EXCLUDED.duration_p50_ms,
	duration_p90_ms = EXCLUDED.duration_p90_ms,
	duration_p95_ms = EXCLUDED.duration_p95_ms,
	duration_p99_ms = EXCLUDED.duration_p99_ms,
	duration_avg_ms = EXCLUDED.duration_avg_ms,
	duration_max_ms = EXCLUDED.duration_max_ms,

	ttft_p50_ms = EXCLUDED.ttft_p50_ms,
	ttft_p90_ms = EXCLUDED.ttft_p90_ms,
	ttft_p95_ms = EXCLUDED.ttft_p95_ms,
	ttft_p99_ms = EXCLUDED.ttft_p99_ms,
	ttft_avg_ms = EXCLUDED.ttft_avg_ms,
	ttft_max_ms = EXCLUDED.ttft_max_ms,

	computed_at = NOW()
`

	_, err := r.db.ExecContext(ctx, q, start, end)
	return err
}
|
||||||
|
|
||||||
|
func (r *opsRepository) GetLatestHourlyBucketStart(ctx context.Context) (time.Time, bool, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return time.Time{}, false, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
|
||||||
|
var value sql.NullTime
|
||||||
|
if err := r.db.QueryRowContext(ctx, `SELECT MAX(bucket_start) FROM ops_metrics_hourly`).Scan(&value); err != nil {
|
||||||
|
return time.Time{}, false, err
|
||||||
|
}
|
||||||
|
if !value.Valid {
|
||||||
|
return time.Time{}, false, nil
|
||||||
|
}
|
||||||
|
return value.Time.UTC(), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *opsRepository) GetLatestDailyBucketDate(ctx context.Context) (time.Time, bool, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return time.Time{}, false, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
|
||||||
|
var value sql.NullTime
|
||||||
|
if err := r.db.QueryRowContext(ctx, `SELECT MAX(bucket_date) FROM ops_metrics_daily`).Scan(&value); err != nil {
|
||||||
|
return time.Time{}, false, err
|
||||||
|
}
|
||||||
|
if !value.Valid {
|
||||||
|
return time.Time{}, false, nil
|
||||||
|
}
|
||||||
|
t := value.Time.UTC()
|
||||||
|
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC), true, nil
|
||||||
|
}
|
||||||
129
backend/internal/repository/ops_repo_realtime_traffic.go
Normal file
129
backend/internal/repository/ops_repo_realtime_traffic.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetRealtimeTrafficSummary computes QPS/TPS current, peak, and average for
// the filter's [StartTime, EndTime] window (max 1 hour) by bucketing
// usage_logs and ops_error_logs per minute and joining the two series.
// Peaks are per-minute maxima converted to per-second rates; "current" is
// delegated to queryCurrentRates (last-minute semantics, see note below).
func (r *opsRepository) GetRealtimeTrafficSummary(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsRealtimeTrafficSummary, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if filter == nil {
		return nil, fmt.Errorf("nil filter")
	}
	if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
		return nil, fmt.Errorf("start_time/end_time required")
	}

	start := filter.StartTime.UTC()
	end := filter.EndTime.UTC()
	if start.After(end) {
		return nil, fmt.Errorf("start_time must be <= end_time")
	}

	// Realtime view is capped at one hour to keep the per-minute scan cheap.
	window := end.Sub(start)
	if window <= 0 {
		return nil, fmt.Errorf("invalid time window")
	}
	if window > time.Hour {
		return nil, fmt.Errorf("window too large")
	}

	// Build WHERE clauses with shared placeholder numbering: usage args take
	// $1..$(next-1), error args continue from $next.
	usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
	errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)

	q := `
WITH usage_buckets AS (
	SELECT
		date_trunc('minute', ul.created_at) AS bucket,
		COALESCE(COUNT(*), 0) AS success_count,
		COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_sum
	FROM usage_logs ul
	` + usageJoin + `
	` + usageWhere + `
	GROUP BY 1
),
error_buckets AS (
	SELECT
		date_trunc('minute', created_at) AS bucket,
		COALESCE(COUNT(*), 0) AS error_count
	FROM ops_error_logs
	` + errorWhere + `
	AND COALESCE(status_code, 0) >= 400
	GROUP BY 1
),
combined AS (
	SELECT
		COALESCE(u.bucket, e.bucket) AS bucket,
		COALESCE(u.success_count, 0) AS success_count,
		COALESCE(u.token_sum, 0) AS token_sum,
		COALESCE(e.error_count, 0) AS error_count,
		COALESCE(u.success_count, 0) + COALESCE(e.error_count, 0) AS request_total
	FROM usage_buckets u
	FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
)
SELECT
	COALESCE(SUM(success_count), 0) AS success_total,
	COALESCE(SUM(error_count), 0) AS error_total,
	COALESCE(SUM(token_sum), 0) AS token_total,
	COALESCE(MAX(request_total), 0) AS peak_requests_per_min,
	COALESCE(MAX(token_sum), 0) AS peak_tokens_per_min
FROM combined`

	// usage args come first, matching the placeholder numbering above.
	args := append(usageArgs, errorArgs...)
	var successCount int64
	var errorTotal int64
	var tokenConsumed int64
	var peakRequestsPerMin int64
	var peakTokensPerMin int64
	if err := r.db.QueryRowContext(ctx, q, args...).Scan(
		&successCount,
		&errorTotal,
		&tokenConsumed,
		&peakRequestsPerMin,
		&peakTokensPerMin,
	); err != nil {
		return nil, err
	}

	// Guard against division by zero for averages (window > 0 was already
	// checked, this is defensive).
	windowSeconds := window.Seconds()
	if windowSeconds <= 0 {
		windowSeconds = 1
	}

	requestCountTotal := successCount + errorTotal
	qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
	tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)

	// Keep "current" consistent with the dashboard overview semantics: last 1 minute.
	// This remains "within the selected window" since end=start+window.
	qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
	if err != nil {
		return nil, err
	}

	// Peaks are per-minute maxima scaled down to per-second rates.
	qpsPeak := roundTo1DP(float64(peakRequestsPerMin) / 60.0)
	tpsPeak := roundTo1DP(float64(peakTokensPerMin) / 60.0)

	return &service.OpsRealtimeTrafficSummary{
		StartTime: start,
		EndTime:   end,
		Platform:  strings.TrimSpace(filter.Platform),
		GroupID:   filter.GroupID,
		QPS: service.OpsRateSummary{
			Current: qpsCurrent,
			Peak:    qpsPeak,
			Avg:     qpsAvg,
		},
		TPS: service.OpsRateSummary{
			Current: tpsCurrent,
			Peak:    tpsPeak,
			Avg:     tpsAvg,
		},
	}, nil
}
|
||||||
286
backend/internal/repository/ops_repo_request_details.go
Normal file
286
backend/internal/repository/ops_repo_request_details.go
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ListRequestDetails returns a single paginated timeline that merges successful
// requests (usage_logs) and failed requests (ops_error_logs, status >= 400) for
// the window produced by filter.Normalize(). It returns the page of rows plus
// the total row count across both sources.
//
// Placeholder discipline: $1/$2 are always the time window used inside the CTE;
// every optional filter appends its own $N based on len(args)+1, so conditions
// must be added strictly in the order they are built below.
func (r *opsRepository) ListRequestDetails(ctx context.Context, filter *service.OpsRequestDetailFilter) ([]*service.OpsRequestDetail, int64, error) {
	if r == nil || r.db == nil {
		return nil, 0, fmt.Errorf("nil ops repository")
	}

	// NOTE(review): filter is dereferenced here (Normalize) before the
	// `filter != nil` checks below — confirm Normalize is nil-receiver safe.
	page, pageSize, startTime, endTime := filter.Normalize()
	offset := (page - 1) * pageSize

	conditions := make([]string, 0, 16)
	args := make([]any, 0, 24)

	// Placeholders $1/$2 reserved for time window inside the CTE.
	args = append(args, startTime.UTC(), endTime.UTC())

	// addCondition appends one WHERE predicate and its bound values, keeping
	// the predicate's $N placeholders aligned with args ordering.
	addCondition := func(condition string, values ...any) {
		conditions = append(conditions, condition)
		args = append(args, values...)
	}

	if filter != nil {
		// kind filters to one side of the UNION; "" / "all" means both.
		if kind := strings.TrimSpace(strings.ToLower(filter.Kind)); kind != "" && kind != "all" {
			if kind != string(service.OpsRequestKindSuccess) && kind != string(service.OpsRequestKindError) {
				return nil, 0, fmt.Errorf("invalid kind")
			}
			addCondition(fmt.Sprintf("kind = $%d", len(args)+1), kind)
		}

		if platform := strings.TrimSpace(strings.ToLower(filter.Platform)); platform != "" {
			addCondition(fmt.Sprintf("platform = $%d", len(args)+1), platform)
		}
		if filter.GroupID != nil && *filter.GroupID > 0 {
			addCondition(fmt.Sprintf("group_id = $%d", len(args)+1), *filter.GroupID)
		}

		if filter.UserID != nil && *filter.UserID > 0 {
			addCondition(fmt.Sprintf("user_id = $%d", len(args)+1), *filter.UserID)
		}
		if filter.APIKeyID != nil && *filter.APIKeyID > 0 {
			addCondition(fmt.Sprintf("api_key_id = $%d", len(args)+1), *filter.APIKeyID)
		}
		if filter.AccountID != nil && *filter.AccountID > 0 {
			addCondition(fmt.Sprintf("account_id = $%d", len(args)+1), *filter.AccountID)
		}

		if model := strings.TrimSpace(filter.Model); model != "" {
			addCondition(fmt.Sprintf("model = $%d", len(args)+1), model)
		}
		if requestID := strings.TrimSpace(filter.RequestID); requestID != "" {
			addCondition(fmt.Sprintf("request_id = $%d", len(args)+1), requestID)
		}
		// Free-text query: case-insensitive substring match over request_id,
		// model and message. Three placeholders share the same LIKE pattern.
		if q := strings.TrimSpace(filter.Query); q != "" {
			like := "%" + strings.ToLower(q) + "%"
			startIdx := len(args) + 1
			addCondition(
				fmt.Sprintf("(LOWER(COALESCE(request_id,'')) LIKE $%d OR LOWER(COALESCE(model,'')) LIKE $%d OR LOWER(COALESCE(message,'')) LIKE $%d)",
					startIdx, startIdx+1, startIdx+2,
				),
				like, like, like,
			)
		}

		if filter.MinDurationMs != nil {
			addCondition(fmt.Sprintf("duration_ms >= $%d", len(args)+1), *filter.MinDurationMs)
		}
		if filter.MaxDurationMs != nil {
			addCondition(fmt.Sprintf("duration_ms <= $%d", len(args)+1), *filter.MaxDurationMs)
		}
	}

	where := ""
	if len(conditions) > 0 {
		where = "WHERE " + strings.Join(conditions, " AND ")
	}

	// The CTE normalizes both sources into one row shape: success rows carry
	// NULL error columns; error rows resolve request_id/platform through
	// fallbacks and are limited to real failures (status >= 400).
	cte := `
WITH combined AS (
	SELECT
		'success'::TEXT AS kind,
		ul.created_at AS created_at,
		ul.request_id AS request_id,
		COALESCE(NULLIF(g.platform, ''), NULLIF(a.platform, ''), '') AS platform,
		ul.model AS model,
		ul.duration_ms AS duration_ms,
		NULL::INT AS status_code,
		NULL::BIGINT AS error_id,
		NULL::TEXT AS phase,
		NULL::TEXT AS severity,
		NULL::TEXT AS message,
		ul.user_id AS user_id,
		ul.api_key_id AS api_key_id,
		ul.account_id AS account_id,
		ul.group_id AS group_id,
		ul.stream AS stream
	FROM usage_logs ul
	LEFT JOIN groups g ON g.id = ul.group_id
	LEFT JOIN accounts a ON a.id = ul.account_id
	WHERE ul.created_at >= $1 AND ul.created_at < $2

	UNION ALL

	SELECT
		'error'::TEXT AS kind,
		o.created_at AS created_at,
		COALESCE(NULLIF(o.request_id,''), NULLIF(o.client_request_id,''), '') AS request_id,
		COALESCE(NULLIF(o.platform, ''), NULLIF(g.platform, ''), NULLIF(a.platform, ''), '') AS platform,
		o.model AS model,
		o.duration_ms AS duration_ms,
		o.status_code AS status_code,
		o.id AS error_id,
		o.error_phase AS phase,
		o.severity AS severity,
		o.error_message AS message,
		o.user_id AS user_id,
		o.api_key_id AS api_key_id,
		o.account_id AS account_id,
		o.group_id AS group_id,
		o.stream AS stream
	FROM ops_error_logs o
	LEFT JOIN groups g ON g.id = o.group_id
	LEFT JOIN accounts a ON a.id = o.account_id
	WHERE o.created_at >= $1 AND o.created_at < $2
	AND COALESCE(o.status_code, 0) >= 400
)
`

	// Count first, over the same CTE and the same filter predicates.
	countQuery := fmt.Sprintf(`%s SELECT COUNT(1) FROM combined %s`, cte, where)
	var total int64
	if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		if err == sql.ErrNoRows {
			total = 0
		} else {
			return nil, 0, err
		}
	}

	// Whitelisted sort expressions only; anything else is rejected so user
	// input can never reach the ORDER BY clause directly.
	sort := "ORDER BY created_at DESC"
	if filter != nil {
		switch strings.TrimSpace(strings.ToLower(filter.Sort)) {
		case "", "created_at_desc":
			// default
		case "duration_desc":
			sort = "ORDER BY duration_ms DESC NULLS LAST, created_at DESC"
		default:
			return nil, 0, fmt.Errorf("invalid sort")
		}
	}

	// LIMIT/OFFSET take the last two placeholders, after all filter args.
	listQuery := fmt.Sprintf(`
%s
SELECT
	kind,
	created_at,
	request_id,
	platform,
	model,
	duration_ms,
	status_code,
	error_id,
	phase,
	severity,
	message,
	user_id,
	api_key_id,
	account_id,
	group_id,
	stream
FROM combined
%s
%s
LIMIT $%d OFFSET $%d
`, cte, where, sort, len(args)+1, len(args)+2)

	listArgs := append(append([]any{}, args...), pageSize, offset)
	rows, err := r.db.QueryContext(ctx, listQuery, listArgs...)
	if err != nil {
		return nil, 0, err
	}
	defer func() { _ = rows.Close() }()

	// Helpers converting nullable DB integers into optional Go values.
	toIntPtr := func(v sql.NullInt64) *int {
		if !v.Valid {
			return nil
		}
		i := int(v.Int64)
		return &i
	}
	toInt64Ptr := func(v sql.NullInt64) *int64 {
		if !v.Valid {
			return nil
		}
		i := v.Int64
		return &i
	}

	out := make([]*service.OpsRequestDetail, 0, pageSize)
	for rows.Next() {
		var (
			kind      string
			createdAt time.Time
			requestID sql.NullString
			platform  sql.NullString
			model     sql.NullString

			durationMs sql.NullInt64
			statusCode sql.NullInt64
			errorID    sql.NullInt64

			phase    sql.NullString
			severity sql.NullString
			message  sql.NullString

			userID    sql.NullInt64
			apiKeyID  sql.NullInt64
			accountID sql.NullInt64
			groupID   sql.NullInt64

			stream bool
		)

		if err := rows.Scan(
			&kind,
			&createdAt,
			&requestID,
			&platform,
			&model,
			&durationMs,
			&statusCode,
			&errorID,
			&phase,
			&severity,
			&message,
			&userID,
			&apiKeyID,
			&accountID,
			&groupID,
			&stream,
		); err != nil {
			return nil, 0, err
		}

		item := &service.OpsRequestDetail{
			Kind:      service.OpsRequestKind(kind),
			CreatedAt: createdAt,
			RequestID: strings.TrimSpace(requestID.String),
			Platform:  strings.TrimSpace(platform.String),
			Model:     strings.TrimSpace(model.String),

			DurationMs: toIntPtr(durationMs),
			StatusCode: toIntPtr(statusCode),
			ErrorID:    toInt64Ptr(errorID),
			Phase:      phase.String,
			Severity:   severity.String,
			Message:    message.String,

			UserID:    toInt64Ptr(userID),
			APIKeyID:  toInt64Ptr(apiKeyID),
			AccountID: toInt64Ptr(accountID),
			GroupID:   toInt64Ptr(groupID),

			Stream: stream,
		}

		// Rows with no resolvable platform are labeled explicitly.
		if item.Platform == "" {
			item.Platform = "unknown"
		}

		out = append(out, item)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}

	return out, total, nil
}
|
||||||
573
backend/internal/repository/ops_repo_trends.go
Normal file
573
backend/internal/repository/ops_repo_trends.go
Normal file
@@ -0,0 +1,573 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetThroughputTrend returns a time-bucketed throughput series (requests and
// tokens per bucket, with derived QPS/TPS) over [StartTime, EndTime), merging
// successful traffic from usage_logs with failures from ops_error_logs.
// Depending on the filter it also attaches one drill-down breakdown:
// per-platform totals when no platform is selected, or top groups when a
// platform (but no group) is selected.
func (r *opsRepository) GetThroughputTrend(ctx context.Context, filter *service.OpsDashboardFilter, bucketSeconds int) (*service.OpsThroughputTrendResponse, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if filter == nil {
		return nil, fmt.Errorf("nil filter")
	}
	if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
		return nil, fmt.Errorf("start_time/end_time required")
	}

	// Clamp bucket width to the supported set (1m / 5m / 1h).
	if bucketSeconds <= 0 {
		bucketSeconds = 60
	}
	if bucketSeconds != 60 && bucketSeconds != 300 && bucketSeconds != 3600 {
		// Keep a small, predictable set of supported buckets for now.
		bucketSeconds = 60
	}

	start := filter.StartTime.UTC()
	end := filter.EndTime.UTC()

	// buildUsageWhere consumes placeholders starting at $1 and reports the
	// next free index, which buildErrorWhere continues from — so usageArgs
	// must precede errorArgs in the final args slice.
	usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
	errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)

	usageBucketExpr := opsBucketExprForUsage(bucketSeconds)
	errorBucketExpr := opsBucketExprForError(bucketSeconds)

	// FULL OUTER JOIN keeps buckets that contain only successes or only errors.
	q := `
WITH usage_buckets AS (
	SELECT ` + usageBucketExpr + ` AS bucket,
	       COUNT(*) AS success_count,
	       COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
	FROM usage_logs ul
	` + usageJoin + `
	` + usageWhere + `
	GROUP BY 1
),
error_buckets AS (
	SELECT ` + errorBucketExpr + ` AS bucket,
	       COUNT(*) AS error_count
	FROM ops_error_logs
	` + errorWhere + `
	AND COALESCE(status_code, 0) >= 400
	GROUP BY 1
),
combined AS (
	SELECT COALESCE(u.bucket, e.bucket) AS bucket,
	       COALESCE(u.success_count, 0) AS success_count,
	       COALESCE(e.error_count, 0) AS error_count,
	       COALESCE(u.token_consumed, 0) AS token_consumed
	FROM usage_buckets u
	FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
)
SELECT
	bucket,
	(success_count + error_count) AS request_count,
	token_consumed
FROM combined
ORDER BY bucket ASC`

	args := append(usageArgs, errorArgs...)

	rows, err := r.db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	points := make([]*service.OpsThroughputTrendPoint, 0, 256)
	for rows.Next() {
		var bucket time.Time
		var requests int64
		var tokens sql.NullInt64
		if err := rows.Scan(&bucket, &requests, &tokens); err != nil {
			return nil, err
		}
		tokenConsumed := int64(0)
		if tokens.Valid {
			tokenConsumed = tokens.Int64
		}

		// Per-bucket rates: divide counts by the bucket width in seconds.
		denom := float64(bucketSeconds)
		if denom <= 0 {
			denom = 60
		}
		qps := roundTo1DP(float64(requests) / denom)
		tps := roundTo1DP(float64(tokenConsumed) / denom)

		points = append(points, &service.OpsThroughputTrendPoint{
			BucketStart:   bucket.UTC(),
			RequestCount:  requests,
			TokenConsumed: tokenConsumed,
			QPS:           qps,
			TPS:           tps,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// Fill missing buckets with zeros so charts render continuous timelines.
	points = fillOpsThroughputBuckets(start, end, bucketSeconds, points)

	var byPlatform []*service.OpsThroughputPlatformBreakdownItem
	var topGroups []*service.OpsThroughputGroupBreakdownItem

	// NOTE(review): filter was already nil-checked at the top, so these
	// `filter != nil` guards are redundant (kept as-is for safety).
	platform := ""
	if filter != nil {
		platform = strings.TrimSpace(strings.ToLower(filter.Platform))
	}
	groupID := (*int64)(nil)
	if filter != nil {
		groupID = filter.GroupID
	}

	// Drilldown helpers:
	// - No platform/group: totals by platform
	// - Platform selected but no group: top groups in that platform
	if platform == "" && (groupID == nil || *groupID <= 0) {
		items, err := r.getThroughputBreakdownByPlatform(ctx, start, end)
		if err != nil {
			return nil, err
		}
		byPlatform = items
	} else if platform != "" && (groupID == nil || *groupID <= 0) {
		items, err := r.getThroughputTopGroupsByPlatform(ctx, start, end, platform, 10)
		if err != nil {
			return nil, err
		}
		topGroups = items
	}

	return &service.OpsThroughputTrendResponse{
		Bucket: opsBucketLabel(bucketSeconds),
		Points: points,

		ByPlatform: byPlatform,
		TopGroups:  topGroups,
	}, nil
}
|
||||||
|
|
||||||
|
// getThroughputBreakdownByPlatform aggregates request and token totals per
// platform over [start, end). Success rows resolve their platform via the
// group first, falling back to the account; error rows carry their own
// platform column. Platforms that cannot be resolved (NULL/empty) are dropped.
// Results are sorted by request count, descending.
func (r *opsRepository) getThroughputBreakdownByPlatform(ctx context.Context, start, end time.Time) ([]*service.OpsThroughputPlatformBreakdownItem, error) {
	// FULL OUTER JOIN keeps platforms that only appear on one side
	// (only successes or only errors in the window).
	q := `
WITH usage_totals AS (
	SELECT COALESCE(NULLIF(g.platform,''), a.platform) AS platform,
	       COUNT(*) AS success_count,
	       COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
	FROM usage_logs ul
	LEFT JOIN groups g ON g.id = ul.group_id
	LEFT JOIN accounts a ON a.id = ul.account_id
	WHERE ul.created_at >= $1 AND ul.created_at < $2
	GROUP BY 1
),
error_totals AS (
	SELECT platform,
	       COUNT(*) AS error_count
	FROM ops_error_logs
	WHERE created_at >= $1 AND created_at < $2
	AND COALESCE(status_code, 0) >= 400
	AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
	GROUP BY 1
),
combined AS (
	SELECT COALESCE(u.platform, e.platform) AS platform,
	       COALESCE(u.success_count, 0) AS success_count,
	       COALESCE(e.error_count, 0) AS error_count,
	       COALESCE(u.token_consumed, 0) AS token_consumed
	FROM usage_totals u
	FULL OUTER JOIN error_totals e ON u.platform = e.platform
)
SELECT platform, (success_count + error_count) AS request_count, token_consumed
FROM combined
WHERE platform IS NOT NULL AND platform <> ''
ORDER BY request_count DESC`

	rows, err := r.db.QueryContext(ctx, q, start, end)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	items := make([]*service.OpsThroughputPlatformBreakdownItem, 0, 8)
	for rows.Next() {
		var platform string
		var requests int64
		var tokens sql.NullInt64
		if err := rows.Scan(&platform, &requests, &tokens); err != nil {
			return nil, err
		}
		// token_consumed is NULL-safe in SQL, but guard the scan anyway.
		tokenConsumed := int64(0)
		if tokens.Valid {
			tokenConsumed = tokens.Int64
		}
		items = append(items, &service.OpsThroughputPlatformBreakdownItem{
			Platform:      platform,
			RequestCount:  requests,
			TokenConsumed: tokenConsumed,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||||
|
|
||||||
|
// getThroughputTopGroupsByPlatform returns the busiest groups (by combined
// success + error request count) for one platform over [start, end).
// platform must be non-blank; limit is clamped into (0, 100] with a default
// of 10. Returns (nil, nil) when platform is blank.
func (r *opsRepository) getThroughputTopGroupsByPlatform(ctx context.Context, start, end time.Time, platform string, limit int) ([]*service.OpsThroughputGroupBreakdownItem, error) {
	if strings.TrimSpace(platform) == "" {
		return nil, nil
	}
	if limit <= 0 || limit > 100 {
		limit = 10
	}

	// FULL OUTER JOIN keeps groups that only have errors (no usage rows);
	// g2 re-resolves the group name for those error-only groups.
	q := `
WITH usage_totals AS (
	SELECT ul.group_id AS group_id,
	       g.name AS group_name,
	       COUNT(*) AS success_count,
	       COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
	FROM usage_logs ul
	JOIN groups g ON g.id = ul.group_id
	WHERE ul.created_at >= $1 AND ul.created_at < $2
	AND g.platform = $3
	GROUP BY 1, 2
),
error_totals AS (
	SELECT group_id,
	       COUNT(*) AS error_count
	FROM ops_error_logs
	WHERE created_at >= $1 AND created_at < $2
	AND platform = $3
	AND group_id IS NOT NULL
	AND COALESCE(status_code, 0) >= 400
	AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
	GROUP BY 1
),
combined AS (
	SELECT COALESCE(u.group_id, e.group_id) AS group_id,
	       COALESCE(u.group_name, g2.name, '') AS group_name,
	       COALESCE(u.success_count, 0) AS success_count,
	       COALESCE(e.error_count, 0) AS error_count,
	       COALESCE(u.token_consumed, 0) AS token_consumed
	FROM usage_totals u
	FULL OUTER JOIN error_totals e ON u.group_id = e.group_id
	LEFT JOIN groups g2 ON g2.id = COALESCE(u.group_id, e.group_id)
)
SELECT group_id, group_name, (success_count + error_count) AS request_count, token_consumed
FROM combined
WHERE group_id IS NOT NULL
ORDER BY request_count DESC
LIMIT $4`

	rows, err := r.db.QueryContext(ctx, q, start, end, platform, limit)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	items := make([]*service.OpsThroughputGroupBreakdownItem, 0, limit)
	for rows.Next() {
		var groupID int64
		var groupName sql.NullString
		var requests int64
		var tokens sql.NullInt64
		if err := rows.Scan(&groupID, &groupName, &requests, &tokens); err != nil {
			return nil, err
		}
		tokenConsumed := int64(0)
		if tokens.Valid {
			tokenConsumed = tokens.Int64
		}
		// group_name may be NULL if the group row was deleted; fall back to "".
		name := ""
		if groupName.Valid {
			name = groupName.String
		}
		items = append(items, &service.OpsThroughputGroupBreakdownItem{
			GroupID:       groupID,
			GroupName:     name,
			RequestCount:  requests,
			TokenConsumed: tokenConsumed,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||||
|
|
||||||
|
// opsBucketExprForUsage returns the SQL expression that truncates
// usage_logs (aliased ul) created_at to the requested bucket width.
// Supported widths are 3600s (hour) and 300s (five minutes); any other
// value falls back to one-minute buckets.
func opsBucketExprForUsage(bucketSeconds int) string {
	if bucketSeconds == 3600 {
		return "date_trunc('hour', ul.created_at)"
	}
	if bucketSeconds == 300 {
		// date_trunc has no 5-minute granularity, so round the epoch instead.
		return "to_timestamp(floor(extract(epoch from ul.created_at) / 300) * 300)"
	}
	return "date_trunc('minute', ul.created_at)"
}
|
||||||
|
|
||||||
|
// opsBucketExprForError returns the SQL expression that truncates
// ops_error_logs.created_at to the requested bucket width. Supported
// widths are 3600s (hour) and 300s (five minutes); any other value
// falls back to one-minute buckets.
func opsBucketExprForError(bucketSeconds int) string {
	if bucketSeconds == 3600 {
		return "date_trunc('hour', created_at)"
	}
	if bucketSeconds == 300 {
		// date_trunc has no 5-minute granularity, so round the epoch instead.
		return "to_timestamp(floor(extract(epoch from created_at) / 300) * 300)"
	}
	return "date_trunc('minute', created_at)"
}
|
||||||
|
|
||||||
|
// opsBucketLabel renders a human-readable label for a bucket width:
// whole hours as "Nh", everything else as "Nm" (floored to minutes,
// minimum "1m"). Non-positive widths also yield "1m".
func opsBucketLabel(bucketSeconds int) string {
	switch {
	case bucketSeconds <= 0:
		return "1m"
	case bucketSeconds%3600 == 0:
		hours := bucketSeconds / 3600
		if hours < 1 {
			hours = 1
		}
		return fmt.Sprintf("%dh", hours)
	default:
		minutes := bucketSeconds / 60
		if minutes < 1 {
			minutes = 1
		}
		return fmt.Sprintf("%dm", minutes)
	}
}
|
||||||
|
|
||||||
|
func opsFloorToBucketStart(t time.Time, bucketSeconds int) time.Time {
|
||||||
|
t = t.UTC()
|
||||||
|
if bucketSeconds <= 0 {
|
||||||
|
bucketSeconds = 60
|
||||||
|
}
|
||||||
|
secs := t.Unix()
|
||||||
|
floored := secs - (secs % int64(bucketSeconds))
|
||||||
|
return time.Unix(floored, 0).UTC()
|
||||||
|
}
|
||||||
|
|
||||||
|
func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points []*service.OpsThroughputTrendPoint) []*service.OpsThroughputTrendPoint {
|
||||||
|
if bucketSeconds <= 0 {
|
||||||
|
bucketSeconds = 60
|
||||||
|
}
|
||||||
|
if !start.Before(end) {
|
||||||
|
return points
|
||||||
|
}
|
||||||
|
|
||||||
|
endMinus := end.Add(-time.Nanosecond)
|
||||||
|
if endMinus.Before(start) {
|
||||||
|
return points
|
||||||
|
}
|
||||||
|
|
||||||
|
first := opsFloorToBucketStart(start, bucketSeconds)
|
||||||
|
last := opsFloorToBucketStart(endMinus, bucketSeconds)
|
||||||
|
step := time.Duration(bucketSeconds) * time.Second
|
||||||
|
|
||||||
|
existing := make(map[int64]*service.OpsThroughputTrendPoint, len(points))
|
||||||
|
for _, p := range points {
|
||||||
|
if p == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
existing[p.BucketStart.UTC().Unix()] = p
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]*service.OpsThroughputTrendPoint, 0, int(last.Sub(first)/step)+1)
|
||||||
|
for cursor := first; !cursor.After(last); cursor = cursor.Add(step) {
|
||||||
|
if p, ok := existing[cursor.Unix()]; ok && p != nil {
|
||||||
|
out = append(out, p)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, &service.OpsThroughputTrendPoint{
|
||||||
|
BucketStart: cursor,
|
||||||
|
RequestCount: 0,
|
||||||
|
TokenConsumed: 0,
|
||||||
|
QPS: 0,
|
||||||
|
TPS: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetErrorTrend returns a time-bucketed error series over
// [StartTime, EndTime), split into total errors, business-limited errors,
// SLA-relevant errors (not business-limited), and provider-owned upstream
// errors broken out by 429 / 529 / everything else. Missing buckets are
// zero-filled so charts render a continuous timeline.
func (r *opsRepository) GetErrorTrend(ctx context.Context, filter *service.OpsDashboardFilter, bucketSeconds int) (*service.OpsErrorTrendResponse, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if filter == nil {
		return nil, fmt.Errorf("nil filter")
	}
	if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
		return nil, fmt.Errorf("start_time/end_time required")
	}

	// Clamp bucket width to the supported set (1m / 5m / 1h).
	if bucketSeconds <= 0 {
		bucketSeconds = 60
	}
	if bucketSeconds != 60 && bucketSeconds != 300 && bucketSeconds != 3600 {
		bucketSeconds = 60
	}

	start := filter.StartTime.UTC()
	end := filter.EndTime.UTC()
	where, args, _ := buildErrorWhere(filter, start, end, 1)
	bucketExpr := opsBucketExprForError(bucketSeconds)

	// One pass over ops_error_logs; each FILTER clause carves out a metric.
	// Upstream metrics prefer upstream_status_code, falling back to status_code.
	q := `
SELECT
	` + bucketExpr + ` AS bucket,
	COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400) AS error_total,
	COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND is_business_limited) AS business_limited,
	COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND NOT is_business_limited) AS error_sla,
	COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)) AS upstream_excl,
	COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 429) AS upstream_429,
	COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 529) AS upstream_529
FROM ops_error_logs
` + where + `
GROUP BY 1
ORDER BY 1 ASC`

	rows, err := r.db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	points := make([]*service.OpsErrorTrendPoint, 0, 256)
	for rows.Next() {
		var bucket time.Time
		var total, businessLimited, sla, upstreamExcl, upstream429, upstream529 int64
		if err := rows.Scan(&bucket, &total, &businessLimited, &sla, &upstreamExcl, &upstream429, &upstream529); err != nil {
			return nil, err
		}
		points = append(points, &service.OpsErrorTrendPoint{
			BucketStart: bucket.UTC(),

			ErrorCountTotal:      total,
			BusinessLimitedCount: businessLimited,
			ErrorCountSLA:        sla,

			UpstreamErrorCountExcl429529: upstreamExcl,
			Upstream429Count:             upstream429,
			Upstream529Count:             upstream529,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// Zero-fill buckets the query did not return.
	points = fillOpsErrorTrendBuckets(start, end, bucketSeconds, points)

	return &service.OpsErrorTrendResponse{
		Bucket: opsBucketLabel(bucketSeconds),
		Points: points,
	}, nil
}
|
||||||
|
|
||||||
|
func fillOpsErrorTrendBuckets(start, end time.Time, bucketSeconds int, points []*service.OpsErrorTrendPoint) []*service.OpsErrorTrendPoint {
|
||||||
|
if bucketSeconds <= 0 {
|
||||||
|
bucketSeconds = 60
|
||||||
|
}
|
||||||
|
if !start.Before(end) {
|
||||||
|
return points
|
||||||
|
}
|
||||||
|
|
||||||
|
endMinus := end.Add(-time.Nanosecond)
|
||||||
|
if endMinus.Before(start) {
|
||||||
|
return points
|
||||||
|
}
|
||||||
|
|
||||||
|
first := opsFloorToBucketStart(start, bucketSeconds)
|
||||||
|
last := opsFloorToBucketStart(endMinus, bucketSeconds)
|
||||||
|
step := time.Duration(bucketSeconds) * time.Second
|
||||||
|
|
||||||
|
existing := make(map[int64]*service.OpsErrorTrendPoint, len(points))
|
||||||
|
for _, p := range points {
|
||||||
|
if p == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
existing[p.BucketStart.UTC().Unix()] = p
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]*service.OpsErrorTrendPoint, 0, int(last.Sub(first)/step)+1)
|
||||||
|
for cursor := first; !cursor.After(last); cursor = cursor.Add(step) {
|
||||||
|
if p, ok := existing[cursor.Unix()]; ok && p != nil {
|
||||||
|
out = append(out, p)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, &service.OpsErrorTrendPoint{
|
||||||
|
BucketStart: cursor,
|
||||||
|
|
||||||
|
ErrorCountTotal: 0,
|
||||||
|
BusinessLimitedCount: 0,
|
||||||
|
ErrorCountSLA: 0,
|
||||||
|
|
||||||
|
UpstreamErrorCountExcl429529: 0,
|
||||||
|
Upstream429Count: 0,
|
||||||
|
Upstream529Count: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetErrorDistribution returns the top 20 error status codes (preferring the
// upstream status code when present) over [StartTime, EndTime), each split
// into SLA-relevant vs business-limited counts, plus the grand total across
// the returned rows. Ordered by per-code total, descending.
func (r *opsRepository) GetErrorDistribution(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsErrorDistributionResponse, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if filter == nil {
		return nil, fmt.Errorf("nil filter")
	}
	if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
		return nil, fmt.Errorf("start_time/end_time required")
	}

	start := filter.StartTime.UTC()
	end := filter.EndTime.UTC()
	where, args, _ := buildErrorWhere(filter, start, end, 1)

	// Group by the effective status code; FILTER splits each group into
	// SLA-relevant and business-limited buckets.
	q := `
SELECT
	COALESCE(upstream_status_code, status_code, 0) AS status_code,
	COUNT(*) AS total,
	COUNT(*) FILTER (WHERE NOT is_business_limited) AS sla,
	COUNT(*) FILTER (WHERE is_business_limited) AS business_limited
FROM ops_error_logs
` + where + `
AND COALESCE(status_code, 0) >= 400
GROUP BY 1
ORDER BY total DESC
LIMIT 20`

	rows, err := r.db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	items := make([]*service.OpsErrorDistributionItem, 0, 16)
	// total sums only the returned (top-20) rows.
	var total int64
	for rows.Next() {
		var statusCode int
		var cntTotal, cntSLA, cntBiz int64
		if err := rows.Scan(&statusCode, &cntTotal, &cntSLA, &cntBiz); err != nil {
			return nil, err
		}
		total += cntTotal
		items = append(items, &service.OpsErrorDistributionItem{
			StatusCode:      statusCode,
			Total:           cntTotal,
			SLA:             cntSLA,
			BusinessLimited: cntBiz,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return &service.OpsErrorDistributionResponse{
		Total: total,
		Items: items,
	}, nil
}
|
||||||
50
backend/internal/repository/ops_repo_window_stats.go
Normal file
50
backend/internal/repository/ops_repo_window_stats.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r *opsRepository) GetWindowStats(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsWindowStats, error) {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, fmt.Errorf("nil ops repository")
|
||||||
|
}
|
||||||
|
if filter == nil {
|
||||||
|
return nil, fmt.Errorf("nil filter")
|
||||||
|
}
|
||||||
|
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||||
|
return nil, fmt.Errorf("start_time/end_time required")
|
||||||
|
}
|
||||||
|
|
||||||
|
start := filter.StartTime.UTC()
|
||||||
|
end := filter.EndTime.UTC()
|
||||||
|
if start.After(end) {
|
||||||
|
return nil, fmt.Errorf("start_time must be <= end_time")
|
||||||
|
}
|
||||||
|
// Bound excessively large windows to prevent accidental heavy queries.
|
||||||
|
if end.Sub(start) > 24*time.Hour {
|
||||||
|
return nil, fmt.Errorf("window too large")
|
||||||
|
}
|
||||||
|
|
||||||
|
successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
errorTotal, _, _, _, _, _, err := r.queryErrorCounts(ctx, filter, start, end)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &service.OpsWindowStats{
|
||||||
|
StartTime: start,
|
||||||
|
EndTime: end,
|
||||||
|
|
||||||
|
SuccessCount: successCount,
|
||||||
|
ErrorCountTotal: errorTotal,
|
||||||
|
TokenConsumed: tokenConsumed,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
276
backend/internal/repository/scheduler_cache.go
Normal file
276
backend/internal/repository/scheduler_cache.go
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Redis key names used by the scheduler cache. Everything lives under the
// "sched:" namespace; per-bucket keys append "<groupID>:<platform>:<mode>"
// (see schedulerBucketKey).
const (
	schedulerBucketSetKey       = "sched:buckets"          // set of every bucket descriptor ever snapshotted
	schedulerOutboxWatermarkKey = "sched:outbox:watermark" // last applied scheduler_outbox event ID
	schedulerAccountPrefix      = "sched:acc:"             // per-account JSON payload, keyed by account ID
	schedulerActivePrefix       = "sched:active:"          // currently active snapshot version for a bucket
	schedulerReadyPrefix        = "sched:ready:"           // "1" once a bucket's snapshot is usable
	schedulerVersionPrefix      = "sched:ver:"             // monotonically increasing snapshot version counter
	schedulerSnapshotPrefix     = "sched:"                 // prefix for versioned snapshot sorted sets
	schedulerLockPrefix         = "sched:lock:"            // per-bucket rebuild lock (SETNX)
)

// schedulerCache implements service.SchedulerCache on top of a Redis client.
type schedulerCache struct {
	rdb *redis.Client
}

// NewSchedulerCache wraps a Redis client as a service.SchedulerCache.
func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache {
	return &schedulerCache{rdb: rdb}
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||||
|
readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
|
||||||
|
readyVal, err := c.rdb.Get(ctx, readyKey).Result()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
if readyVal != "1" {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
|
||||||
|
activeVal, err := c.rdb.Get(ctx, activeKey).Result()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
snapshotKey := schedulerSnapshotKey(bucket, activeVal)
|
||||||
|
ids, err := c.rdb.ZRange(ctx, snapshotKey, 0, -1).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return []*service.Account{}, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make([]string, 0, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
keys = append(keys, schedulerAccountKey(id))
|
||||||
|
}
|
||||||
|
values, err := c.rdb.MGet(ctx, keys...).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
accounts := make([]*service.Account, 0, len(values))
|
||||||
|
for _, val := range values {
|
||||||
|
if val == nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
account, err := decodeCachedAccount(val)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
accounts = append(accounts, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
return accounts, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
|
||||||
|
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
|
||||||
|
oldActive, _ := c.rdb.Get(ctx, activeKey).Result()
|
||||||
|
|
||||||
|
versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
|
||||||
|
version, err := c.rdb.Incr(ctx, versionKey).Result()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
versionStr := strconv.FormatInt(version, 10)
|
||||||
|
snapshotKey := schedulerSnapshotKey(bucket, versionStr)
|
||||||
|
|
||||||
|
pipe := c.rdb.Pipeline()
|
||||||
|
for _, account := range accounts {
|
||||||
|
payload, err := json.Marshal(account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pipe.Set(ctx, schedulerAccountKey(strconv.FormatInt(account.ID, 10)), payload, 0)
|
||||||
|
}
|
||||||
|
if len(accounts) > 0 {
|
||||||
|
// 使用序号作为 score,保持数据库返回的排序语义。
|
||||||
|
members := make([]redis.Z, 0, len(accounts))
|
||||||
|
for idx, account := range accounts {
|
||||||
|
members = append(members, redis.Z{
|
||||||
|
Score: float64(idx),
|
||||||
|
Member: strconv.FormatInt(account.ID, 10),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
pipe.ZAdd(ctx, snapshotKey, members...)
|
||||||
|
} else {
|
||||||
|
pipe.Del(ctx, snapshotKey)
|
||||||
|
}
|
||||||
|
pipe.Set(ctx, activeKey, versionStr, 0)
|
||||||
|
pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
|
||||||
|
pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
|
||||||
|
if _, err := pipe.Exec(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldActive != "" && oldActive != versionStr {
|
||||||
|
_ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||||
|
key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
|
||||||
|
val, err := c.rdb.Get(ctx, key).Result()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return decodeCachedAccount(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Account) error {
|
||||||
|
if account == nil || account.ID <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
key := schedulerAccountKey(strconv.FormatInt(account.ID, 10))
|
||||||
|
return c.rdb.Set(ctx, key, payload, 0).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error {
|
||||||
|
if accountID <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make([]string, 0, len(updates))
|
||||||
|
ids := make([]int64, 0, len(updates))
|
||||||
|
for id := range updates {
|
||||||
|
keys = append(keys, schedulerAccountKey(strconv.FormatInt(id, 10)))
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
values, err := c.rdb.MGet(ctx, keys...).Result()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pipe := c.rdb.Pipeline()
|
||||||
|
for i, val := range values {
|
||||||
|
if val == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
account, err := decodeCachedAccount(val)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
account.LastUsedAt = ptrTime(updates[ids[i]])
|
||||||
|
updated, err := json.Marshal(account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pipe.Set(ctx, keys[i], updated, 0)
|
||||||
|
}
|
||||||
|
_, err = pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
|
||||||
|
key := schedulerBucketKey(schedulerLockPrefix, bucket)
|
||||||
|
return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||||
|
raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out := make([]service.SchedulerBucket, 0, len(raw))
|
||||||
|
for _, entry := range raw {
|
||||||
|
bucket, ok := service.ParseSchedulerBucket(entry)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, bucket)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) GetOutboxWatermark(ctx context.Context) (int64, error) {
|
||||||
|
val, err := c.rdb.Get(ctx, schedulerOutboxWatermarkKey).Result()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
id, err := strconv.ParseInt(val, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *schedulerCache) SetOutboxWatermark(ctx context.Context, id int64) error {
|
||||||
|
return c.rdb.Set(ctx, schedulerOutboxWatermarkKey, strconv.FormatInt(id, 10), 0).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func schedulerBucketKey(prefix string, bucket service.SchedulerBucket) string {
|
||||||
|
return fmt.Sprintf("%s%d:%s:%s", prefix, bucket.GroupID, bucket.Platform, bucket.Mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func schedulerSnapshotKey(bucket service.SchedulerBucket, version string) string {
|
||||||
|
return fmt.Sprintf("%s%d:%s:%s:v%s", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// schedulerAccountKey builds the per-account payload key: sched:acc:<id>.
// id is the account's decimal ID as a string.
func schedulerAccountKey(id string) string {
	return schedulerAccountPrefix + id
}
|
||||||
|
|
||||||
|
func ptrTime(t time.Time) *time.Time {
|
||||||
|
return &t
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeCachedAccount(val any) (*service.Account, error) {
|
||||||
|
var payload []byte
|
||||||
|
switch raw := val.(type) {
|
||||||
|
case string:
|
||||||
|
payload = []byte(raw)
|
||||||
|
case []byte:
|
||||||
|
payload = raw
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected account cache type: %T", val)
|
||||||
|
}
|
||||||
|
var account service.Account
|
||||||
|
if err := json.Unmarshal(payload, &account); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &account, nil
|
||||||
|
}
|
||||||
96
backend/internal/repository/scheduler_outbox_repo.go
Normal file
96
backend/internal/repository/scheduler_outbox_repo.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// schedulerOutboxRepository reads scheduler change events from the
// scheduler_outbox table.
type schedulerOutboxRepository struct {
	db *sql.DB
}

// NewSchedulerOutboxRepository wraps a database handle as a
// service.SchedulerOutboxRepository.
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
	return &schedulerOutboxRepository{db: db}
}
|
||||||
|
|
||||||
|
func (r *schedulerOutboxRepository) ListAfter(ctx context.Context, afterID int64, limit int) ([]service.SchedulerOutboxEvent, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 100
|
||||||
|
}
|
||||||
|
rows, err := r.db.QueryContext(ctx, `
|
||||||
|
SELECT id, event_type, account_id, group_id, payload, created_at
|
||||||
|
FROM scheduler_outbox
|
||||||
|
WHERE id > $1
|
||||||
|
ORDER BY id ASC
|
||||||
|
LIMIT $2
|
||||||
|
`, afterID, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = rows.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
events := make([]service.SchedulerOutboxEvent, 0, limit)
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
payloadRaw []byte
|
||||||
|
accountID sql.NullInt64
|
||||||
|
groupID sql.NullInt64
|
||||||
|
event service.SchedulerOutboxEvent
|
||||||
|
)
|
||||||
|
if err := rows.Scan(&event.ID, &event.EventType, &accountID, &groupID, &payloadRaw, &event.CreatedAt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if accountID.Valid {
|
||||||
|
v := accountID.Int64
|
||||||
|
event.AccountID = &v
|
||||||
|
}
|
||||||
|
if groupID.Valid {
|
||||||
|
v := groupID.Int64
|
||||||
|
event.GroupID = &v
|
||||||
|
}
|
||||||
|
if len(payloadRaw) > 0 {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(payloadRaw, &payload); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
event.Payload = payload
|
||||||
|
}
|
||||||
|
events = append(events, event)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *schedulerOutboxRepository) MaxID(ctx context.Context) (int64, error) {
|
||||||
|
var maxID int64
|
||||||
|
if err := r.db.QueryRowContext(ctx, "SELECT COALESCE(MAX(id), 0) FROM scheduler_outbox").Scan(&maxID); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return maxID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType string, accountID *int64, groupID *int64, payload any) error {
|
||||||
|
if exec == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var payloadArg any
|
||||||
|
if payload != nil {
|
||||||
|
encoded, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payloadArg = encoded
|
||||||
|
}
|
||||||
|
_, err := exec.ExecContext(ctx, `
|
||||||
|
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
`, eventType, accountID, groupID, payloadArg)
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSchedulerSnapshotOutboxReplay verifies that an account change recorded
// in the scheduler outbox is replayed into the Redis scheduler cache by the
// snapshot service's polling loop. Requires the integration Redis/DB fixtures
// (testRedis, testEntClient, integrationDB).
func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
	ctx := context.Background()
	rdb := testRedis(t)
	client := testEntClient(t)

	// Start from an empty outbox so only events produced by this test are
	// replayed; best effort, errors deliberately ignored.
	_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")

	accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
	outboxRepo := NewSchedulerOutboxRepository(integrationDB)
	cache := NewSchedulerCache(rdb)

	// Poll every second; full rebuilds disabled so only outbox replay can
	// propagate the change into the cache.
	cfg := &config.Config{
		RunMode: config.RunModeStandard,
		Gateway: config.GatewayConfig{
			Scheduling: config.GatewaySchedulingConfig{
				OutboxPollIntervalSeconds:  1,
				FullRebuildIntervalSeconds: 0,
				DbFallbackEnabled:          true,
			},
		},
	}

	// Unique name per run so repeated test executions don't collide.
	account := &service.Account{
		Name:        "outbox-replay-" + time.Now().Format("150405.000000"),
		Platform:    service.PlatformOpenAI,
		Type:        service.AccountTypeAPIKey,
		Status:      service.StatusActive,
		Schedulable: true,
		Concurrency: 3,
		Priority:    1,
		Credentials: map[string]any{},
		Extra:       map[string]any{},
	}
	require.NoError(t, accountRepo.Create(ctx, account))
	// Seed the cache so the replayed event has an existing entry to update.
	require.NoError(t, cache.SetAccount(ctx, account))

	svc := service.NewSchedulerSnapshotService(cache, outboxRepo, accountRepo, nil, cfg)
	svc.Start()
	t.Cleanup(svc.Stop)

	// Touch the account in the database, then read back the persisted
	// timestamp that the replay is expected to surface in the cache.
	require.NoError(t, accountRepo.UpdateLastUsed(ctx, account.ID))
	updated, err := accountRepo.GetByID(ctx, account.ID)
	require.NoError(t, err)
	require.NotNil(t, updated.LastUsedAt)
	expectedUnix := updated.LastUsedAt.Unix()

	// Within a few poll ticks the cached payload should carry the new
	// LastUsedAt (compared at second precision).
	require.Eventually(t, func() bool {
		cached, err := cache.GetAccount(ctx, account.ID)
		if err != nil || cached == nil || cached.LastUsedAt == nil {
			return false
		}
		return cached.LastUsedAt.Unix() == expectedUnix
	}, 5*time.Second, 100*time.Millisecond)
}
|
||||||
80
backend/internal/repository/timeout_counter_cache.go
Normal file
80
backend/internal/repository/timeout_counter_cache.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Key prefix for per-account timeout counters.
const timeoutCounterPrefix = "timeout_count:account:"

// timeoutCounterIncrScript atomically increments the counter and returns the
// current value. When the key is freshly created (count == 1) it also sets
// the TTL, so the whole counting window expires together.
var timeoutCounterIncrScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])

local count = redis.call('INCR', key)
if count == 1 then
	redis.call('EXPIRE', key, ttl)
end

return count
`)

// timeoutCounterCache implements service.TimeoutCounterCache on Redis.
type timeoutCounterCache struct {
	rdb *redis.Client
}

// NewTimeoutCounterCache creates a timeout-counter cache instance.
func NewTimeoutCounterCache(rdb *redis.Client) service.TimeoutCounterCache {
	return &timeoutCounterCache{rdb: rdb}
}
|
||||||
|
|
||||||
|
// IncrementTimeoutCount 增加账户的超时计数,返回当前计数值
|
||||||
|
// windowMinutes 是计数窗口时间(分钟),超过此时间计数器会自动重置
|
||||||
|
func (c *timeoutCounterCache) IncrementTimeoutCount(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||||
|
|
||||||
|
ttlSeconds := windowMinutes * 60
|
||||||
|
if ttlSeconds < 60 {
|
||||||
|
ttlSeconds = 60 // 最小1分钟
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := timeoutCounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("increment timeout count: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTimeoutCount 获取账户当前的超时计数
|
||||||
|
func (c *timeoutCounterCache) GetTimeoutCount(ctx context.Context, accountID int64) (int64, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||||
|
|
||||||
|
val, err := c.rdb.Get(ctx, key).Int64()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("get timeout count: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetTimeoutCount 重置账户的超时计数
|
||||||
|
func (c *timeoutCounterCache) ResetTimeoutCount(ctx context.Context, accountID int64) error {
|
||||||
|
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTimeoutCountTTL 获取计数器剩余过期时间
|
||||||
|
func (c *timeoutCounterCache) GetTimeoutCountTTL(ctx context.Context, accountID int64) (time.Duration, error) {
|
||||||
|
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||||
|
return c.rdb.TTL(ctx, key).Result()
|
||||||
|
}
|
||||||
@@ -269,16 +269,60 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
|
|||||||
type DashboardStats = usagestats.DashboardStats
|
type DashboardStats = usagestats.DashboardStats
|
||||||
|
|
||||||
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||||
var stats DashboardStats
|
stats := &DashboardStats{}
|
||||||
today := timezone.Today()
|
now := time.Now().UTC()
|
||||||
now := time.Now()
|
todayUTC := truncateToDayUTC(now)
|
||||||
|
|
||||||
// 合并用户统计查询
|
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayUTC, now); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stats.Rpm = rpm
|
||||||
|
stats.Tpm = tpm
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*DashboardStats, error) {
|
||||||
|
startUTC := start.UTC()
|
||||||
|
endUTC := end.UTC()
|
||||||
|
if !endUTC.After(startUTC) {
|
||||||
|
return nil, errors.New("统计时间范围无效")
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := &DashboardStats{}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
todayUTC := truncateToDayUTC(now)
|
||||||
|
|
||||||
|
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayUTC, now); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stats.Rpm = rpm
|
||||||
|
stats.Tpm = tpm
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) fillDashboardEntityStats(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
|
||||||
userStatsQuery := `
|
userStatsQuery := `
|
||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as total_users,
|
COUNT(*) as total_users,
|
||||||
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users,
|
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users
|
||||||
(SELECT COUNT(DISTINCT user_id) FROM usage_logs WHERE created_at >= $2) as active_users
|
|
||||||
FROM users
|
FROM users
|
||||||
WHERE deleted_at IS NULL
|
WHERE deleted_at IS NULL
|
||||||
`
|
`
|
||||||
@@ -286,15 +330,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
ctx,
|
ctx,
|
||||||
r.sql,
|
r.sql,
|
||||||
userStatsQuery,
|
userStatsQuery,
|
||||||
[]any{today, today},
|
[]any{todayUTC},
|
||||||
&stats.TotalUsers,
|
&stats.TotalUsers,
|
||||||
&stats.TodayNewUsers,
|
&stats.TodayNewUsers,
|
||||||
&stats.ActiveUsers,
|
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 合并API Key统计查询
|
|
||||||
apiKeyStatsQuery := `
|
apiKeyStatsQuery := `
|
||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as total_api_keys,
|
COUNT(*) as total_api_keys,
|
||||||
@@ -310,10 +352,9 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
&stats.TotalAPIKeys,
|
&stats.TotalAPIKeys,
|
||||||
&stats.ActiveAPIKeys,
|
&stats.ActiveAPIKeys,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 合并账户统计查询
|
|
||||||
accountStatsQuery := `
|
accountStatsQuery := `
|
||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as total_accounts,
|
COUNT(*) as total_accounts,
|
||||||
@@ -335,22 +376,26 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
&stats.RateLimitAccounts,
|
&stats.RateLimitAccounts,
|
||||||
&stats.OverloadAccounts,
|
&stats.OverloadAccounts,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 累计 Token 统计
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
|
||||||
totalStatsQuery := `
|
totalStatsQuery := `
|
||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as total_requests,
|
COALESCE(SUM(total_requests), 0) as total_requests,
|
||||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
COALESCE(SUM(total_duration_ms), 0) as total_duration_ms
|
||||||
FROM usage_logs
|
FROM usage_dashboard_daily
|
||||||
`
|
`
|
||||||
|
var totalDurationMs int64
|
||||||
if err := scanSingleRow(
|
if err := scanSingleRow(
|
||||||
ctx,
|
ctx,
|
||||||
r.sql,
|
r.sql,
|
||||||
@@ -363,13 +408,100 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
&stats.TotalCacheReadTokens,
|
&stats.TotalCacheReadTokens,
|
||||||
&stats.TotalCost,
|
&stats.TotalCost,
|
||||||
&stats.TotalActualCost,
|
&stats.TotalActualCost,
|
||||||
&stats.AverageDurationMs,
|
&totalDurationMs,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||||
|
if stats.TotalRequests > 0 {
|
||||||
|
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||||||
|
}
|
||||||
|
|
||||||
// 今日 Token 统计
|
todayStatsQuery := `
|
||||||
|
SELECT
|
||||||
|
total_requests as today_requests,
|
||||||
|
input_tokens as today_input_tokens,
|
||||||
|
output_tokens as today_output_tokens,
|
||||||
|
cache_creation_tokens as today_cache_creation_tokens,
|
||||||
|
cache_read_tokens as today_cache_read_tokens,
|
||||||
|
total_cost as today_cost,
|
||||||
|
actual_cost as today_actual_cost,
|
||||||
|
active_users as active_users
|
||||||
|
FROM usage_dashboard_daily
|
||||||
|
WHERE bucket_date = $1::date
|
||||||
|
`
|
||||||
|
if err := scanSingleRow(
|
||||||
|
ctx,
|
||||||
|
r.sql,
|
||||||
|
todayStatsQuery,
|
||||||
|
[]any{todayUTC},
|
||||||
|
&stats.TodayRequests,
|
||||||
|
&stats.TodayInputTokens,
|
||||||
|
&stats.TodayOutputTokens,
|
||||||
|
&stats.TodayCacheCreationTokens,
|
||||||
|
&stats.TodayCacheReadTokens,
|
||||||
|
&stats.TodayCost,
|
||||||
|
&stats.TodayActualCost,
|
||||||
|
&stats.ActiveUsers,
|
||||||
|
); err != nil {
|
||||||
|
if err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||||
|
|
||||||
|
hourlyActiveQuery := `
|
||||||
|
SELECT active_users
|
||||||
|
FROM usage_dashboard_hourly
|
||||||
|
WHERE bucket_start = $1
|
||||||
|
`
|
||||||
|
hourStart := now.UTC().Truncate(time.Hour)
|
||||||
|
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
|
||||||
|
if err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error {
|
||||||
|
totalStatsQuery := `
|
||||||
|
SELECT
|
||||||
|
COUNT(*) as total_requests,
|
||||||
|
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||||
|
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||||
|
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||||
|
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||||
|
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||||
|
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms
|
||||||
|
FROM usage_logs
|
||||||
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
|
`
|
||||||
|
var totalDurationMs int64
|
||||||
|
if err := scanSingleRow(
|
||||||
|
ctx,
|
||||||
|
r.sql,
|
||||||
|
totalStatsQuery,
|
||||||
|
[]any{startUTC, endUTC},
|
||||||
|
&stats.TotalRequests,
|
||||||
|
&stats.TotalInputTokens,
|
||||||
|
&stats.TotalOutputTokens,
|
||||||
|
&stats.TotalCacheCreationTokens,
|
||||||
|
&stats.TotalCacheReadTokens,
|
||||||
|
&stats.TotalCost,
|
||||||
|
&stats.TotalActualCost,
|
||||||
|
&totalDurationMs,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||||
|
if stats.TotalRequests > 0 {
|
||||||
|
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
todayEnd := todayUTC.Add(24 * time.Hour)
|
||||||
todayStatsQuery := `
|
todayStatsQuery := `
|
||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as today_requests,
|
COUNT(*) as today_requests,
|
||||||
@@ -380,13 +512,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
COALESCE(SUM(total_cost), 0) as today_cost,
|
COALESCE(SUM(total_cost), 0) as today_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE created_at >= $1
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
`
|
`
|
||||||
if err := scanSingleRow(
|
if err := scanSingleRow(
|
||||||
ctx,
|
ctx,
|
||||||
r.sql,
|
r.sql,
|
||||||
todayStatsQuery,
|
todayStatsQuery,
|
||||||
[]any{today},
|
[]any{todayUTC, todayEnd},
|
||||||
&stats.TodayRequests,
|
&stats.TodayRequests,
|
||||||
&stats.TodayInputTokens,
|
&stats.TodayInputTokens,
|
||||||
&stats.TodayOutputTokens,
|
&stats.TodayOutputTokens,
|
||||||
@@ -395,19 +527,31 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
&stats.TodayCost,
|
&stats.TodayCost,
|
||||||
&stats.TodayActualCost,
|
&stats.TodayActualCost,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||||
|
|
||||||
// 性能指标:RPM 和 TPM(最近1分钟,全局)
|
activeUsersQuery := `
|
||||||
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
SELECT COUNT(DISTINCT user_id) as active_users
|
||||||
if err != nil {
|
FROM usage_logs
|
||||||
return nil, err
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
|
`
|
||||||
|
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
stats.Rpm = rpm
|
|
||||||
stats.Tpm = tpm
|
|
||||||
|
|
||||||
return &stats, nil
|
hourStart := now.UTC().Truncate(time.Hour)
|
||||||
|
hourEnd := hourStart.Add(time.Hour)
|
||||||
|
hourlyActiveQuery := `
|
||||||
|
SELECT COUNT(DISTINCT user_id) as active_users
|
||||||
|
FROM usage_logs
|
||||||
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
|
`
|
||||||
|
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
@@ -198,14 +197,14 @@ func (s *UsageLogRepoSuite) TestListWithFilters() {
|
|||||||
// --- GetDashboardStats ---
|
// --- GetDashboardStats ---
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||||
now := time.Now()
|
now := time.Now().UTC()
|
||||||
todayStart := timezone.Today()
|
todayStart := truncateToDayUTC(now)
|
||||||
baseStats, err := s.repo.GetDashboardStats(s.ctx)
|
baseStats, err := s.repo.GetDashboardStats(s.ctx)
|
||||||
s.Require().NoError(err, "GetDashboardStats base")
|
s.Require().NoError(err, "GetDashboardStats base")
|
||||||
|
|
||||||
userToday := mustCreateUser(s.T(), s.client, &service.User{
|
userToday := mustCreateUser(s.T(), s.client, &service.User{
|
||||||
Email: "today@example.com",
|
Email: "today@example.com",
|
||||||
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
|
CreatedAt: testMaxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
})
|
})
|
||||||
userOld := mustCreateUser(s.T(), s.client, &service.User{
|
userOld := mustCreateUser(s.T(), s.client, &service.User{
|
||||||
@@ -238,7 +237,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
|||||||
TotalCost: 1.5,
|
TotalCost: 1.5,
|
||||||
ActualCost: 1.2,
|
ActualCost: 1.2,
|
||||||
DurationMs: &d1,
|
DurationMs: &d1,
|
||||||
CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
|
CreatedAt: testMaxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
|
||||||
}
|
}
|
||||||
_, err = s.repo.Create(s.ctx, logToday)
|
_, err = s.repo.Create(s.ctx, logToday)
|
||||||
s.Require().NoError(err, "Create logToday")
|
s.Require().NoError(err, "Create logToday")
|
||||||
@@ -273,6 +272,11 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
|||||||
_, err = s.repo.Create(s.ctx, logPerf)
|
_, err = s.repo.Create(s.ctx, logPerf)
|
||||||
s.Require().NoError(err, "Create logPerf")
|
s.Require().NoError(err, "Create logPerf")
|
||||||
|
|
||||||
|
aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx)
|
||||||
|
aggStart := todayStart.Add(-2 * time.Hour)
|
||||||
|
aggEnd := now.Add(2 * time.Minute)
|
||||||
|
s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd), "AggregateRange")
|
||||||
|
|
||||||
stats, err := s.repo.GetDashboardStats(s.ctx)
|
stats, err := s.repo.GetDashboardStats(s.ctx)
|
||||||
s.Require().NoError(err, "GetDashboardStats")
|
s.Require().NoError(err, "GetDashboardStats")
|
||||||
|
|
||||||
@@ -303,6 +307,80 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
|||||||
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
|
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestDashboardStatsWithRange_Fallback() {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
todayStart := truncateToDayUTC(now)
|
||||||
|
rangeStart := todayStart.Add(-24 * time.Hour)
|
||||||
|
rangeEnd := now.Add(1 * time.Second)
|
||||||
|
|
||||||
|
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u1@test.com"})
|
||||||
|
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u2@test.com"})
|
||||||
|
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-range-1", Name: "k1"})
|
||||||
|
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-range-2", Name: "k2"})
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-range"})
|
||||||
|
|
||||||
|
d1, d2, d3 := 100, 200, 300
|
||||||
|
logOutside := &service.UsageLog{
|
||||||
|
UserID: user1.ID,
|
||||||
|
APIKeyID: apiKey1.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 7,
|
||||||
|
OutputTokens: 8,
|
||||||
|
TotalCost: 0.8,
|
||||||
|
ActualCost: 0.7,
|
||||||
|
DurationMs: &d3,
|
||||||
|
CreatedAt: rangeStart.Add(-1 * time.Hour),
|
||||||
|
}
|
||||||
|
_, err := s.repo.Create(s.ctx, logOutside)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
logRange := &service.UsageLog{
|
||||||
|
UserID: user1.ID,
|
||||||
|
APIKeyID: apiKey1.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
CacheCreationTokens: 1,
|
||||||
|
CacheReadTokens: 2,
|
||||||
|
TotalCost: 1.0,
|
||||||
|
ActualCost: 0.9,
|
||||||
|
DurationMs: &d1,
|
||||||
|
CreatedAt: rangeStart.Add(2 * time.Hour),
|
||||||
|
}
|
||||||
|
_, err = s.repo.Create(s.ctx, logRange)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
logToday := &service.UsageLog{
|
||||||
|
UserID: user2.ID,
|
||||||
|
APIKeyID: apiKey2.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 5,
|
||||||
|
OutputTokens: 6,
|
||||||
|
CacheReadTokens: 1,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
DurationMs: &d2,
|
||||||
|
CreatedAt: now,
|
||||||
|
}
|
||||||
|
_, err = s.repo.Create(s.ctx, logToday)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
stats, err := s.repo.GetDashboardStatsWithRange(s.ctx, rangeStart, rangeEnd)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(int64(2), stats.TotalRequests)
|
||||||
|
s.Require().Equal(int64(15), stats.TotalInputTokens)
|
||||||
|
s.Require().Equal(int64(26), stats.TotalOutputTokens)
|
||||||
|
s.Require().Equal(int64(1), stats.TotalCacheCreationTokens)
|
||||||
|
s.Require().Equal(int64(3), stats.TotalCacheReadTokens)
|
||||||
|
s.Require().Equal(int64(45), stats.TotalTokens)
|
||||||
|
s.Require().Equal(1.5, stats.TotalCost)
|
||||||
|
s.Require().Equal(1.4, stats.TotalActualCost)
|
||||||
|
s.Require().InEpsilon(150.0, stats.AverageDurationMs, 0.0001)
|
||||||
|
}
|
||||||
|
|
||||||
// --- GetUserDashboardStats ---
|
// --- GetUserDashboardStats ---
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||||
@@ -333,6 +411,159 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
|||||||
s.Require().Equal(int64(30), stats.Tokens)
|
s.Require().Equal(int64(30), stats.Tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
|
||||||
|
now := time.Now().UTC().Truncate(time.Second)
|
||||||
|
// 使用固定的时间偏移确保 hour1 和 hour2 在同一天且都在过去
|
||||||
|
// 选择当天 02:00 和 03:00 作为测试时间点(基于 now 的日期)
|
||||||
|
dayStart := truncateToDayUTC(now)
|
||||||
|
hour1 := dayStart.Add(2 * time.Hour) // 当天 02:00
|
||||||
|
hour2 := dayStart.Add(3 * time.Hour) // 当天 03:00
|
||||||
|
// 如果当前时间早于 hour2,则使用昨天的时间
|
||||||
|
if now.Before(hour2.Add(time.Hour)) {
|
||||||
|
dayStart = dayStart.Add(-24 * time.Hour)
|
||||||
|
hour1 = dayStart.Add(2 * time.Hour)
|
||||||
|
hour2 = dayStart.Add(3 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u1@test.com"})
|
||||||
|
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u2@test.com"})
|
||||||
|
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-agg-1", Name: "k1"})
|
||||||
|
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-agg-2", Name: "k2"})
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-agg"})
|
||||||
|
|
||||||
|
d1, d2, d3 := 100, 200, 150
|
||||||
|
log1 := &service.UsageLog{
|
||||||
|
UserID: user1.ID,
|
||||||
|
APIKeyID: apiKey1.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
CacheCreationTokens: 2,
|
||||||
|
CacheReadTokens: 1,
|
||||||
|
TotalCost: 1.0,
|
||||||
|
ActualCost: 0.9,
|
||||||
|
DurationMs: &d1,
|
||||||
|
CreatedAt: hour1.Add(5 * time.Minute),
|
||||||
|
}
|
||||||
|
_, err := s.repo.Create(s.ctx, log1)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
log2 := &service.UsageLog{
|
||||||
|
UserID: user1.ID,
|
||||||
|
APIKeyID: apiKey1.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 5,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
DurationMs: &d2,
|
||||||
|
CreatedAt: hour1.Add(20 * time.Minute),
|
||||||
|
}
|
||||||
|
_, err = s.repo.Create(s.ctx, log2)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
log3 := &service.UsageLog{
|
||||||
|
UserID: user2.ID,
|
||||||
|
APIKeyID: apiKey2.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 7,
|
||||||
|
OutputTokens: 8,
|
||||||
|
TotalCost: 0.7,
|
||||||
|
ActualCost: 0.7,
|
||||||
|
DurationMs: &d3,
|
||||||
|
CreatedAt: hour2.Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
_, err = s.repo.Create(s.ctx, log3)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx)
|
||||||
|
aggStart := hour1.Add(-5 * time.Minute)
|
||||||
|
aggEnd := hour2.Add(time.Hour) // 确保覆盖 hour2 的所有数据
|
||||||
|
s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd))
|
||||||
|
|
||||||
|
type hourlyRow struct {
|
||||||
|
totalRequests int64
|
||||||
|
inputTokens int64
|
||||||
|
outputTokens int64
|
||||||
|
cacheCreationTokens int64
|
||||||
|
cacheReadTokens int64
|
||||||
|
totalCost float64
|
||||||
|
actualCost float64
|
||||||
|
totalDurationMs int64
|
||||||
|
activeUsers int64
|
||||||
|
}
|
||||||
|
fetchHourly := func(bucketStart time.Time) hourlyRow {
|
||||||
|
var row hourlyRow
|
||||||
|
err := scanSingleRow(s.ctx, s.tx, `
|
||||||
|
SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens,
|
||||||
|
total_cost, actual_cost, total_duration_ms, active_users
|
||||||
|
FROM usage_dashboard_hourly
|
||||||
|
WHERE bucket_start = $1
|
||||||
|
`, []any{bucketStart}, &row.totalRequests, &row.inputTokens, &row.outputTokens,
|
||||||
|
&row.cacheCreationTokens, &row.cacheReadTokens, &row.totalCost, &row.actualCost,
|
||||||
|
&row.totalDurationMs, &row.activeUsers,
|
||||||
|
)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
return row
|
||||||
|
}
|
||||||
|
|
||||||
|
hour1Row := fetchHourly(hour1)
|
||||||
|
s.Require().Equal(int64(2), hour1Row.totalRequests)
|
||||||
|
s.Require().Equal(int64(15), hour1Row.inputTokens)
|
||||||
|
s.Require().Equal(int64(25), hour1Row.outputTokens)
|
||||||
|
s.Require().Equal(int64(2), hour1Row.cacheCreationTokens)
|
||||||
|
s.Require().Equal(int64(1), hour1Row.cacheReadTokens)
|
||||||
|
s.Require().Equal(1.5, hour1Row.totalCost)
|
||||||
|
s.Require().Equal(1.4, hour1Row.actualCost)
|
||||||
|
s.Require().Equal(int64(300), hour1Row.totalDurationMs)
|
||||||
|
s.Require().Equal(int64(1), hour1Row.activeUsers)
|
||||||
|
|
||||||
|
hour2Row := fetchHourly(hour2)
|
||||||
|
s.Require().Equal(int64(1), hour2Row.totalRequests)
|
||||||
|
s.Require().Equal(int64(7), hour2Row.inputTokens)
|
||||||
|
s.Require().Equal(int64(8), hour2Row.outputTokens)
|
||||||
|
s.Require().Equal(int64(0), hour2Row.cacheCreationTokens)
|
||||||
|
s.Require().Equal(int64(0), hour2Row.cacheReadTokens)
|
||||||
|
s.Require().Equal(0.7, hour2Row.totalCost)
|
||||||
|
s.Require().Equal(0.7, hour2Row.actualCost)
|
||||||
|
s.Require().Equal(int64(150), hour2Row.totalDurationMs)
|
||||||
|
s.Require().Equal(int64(1), hour2Row.activeUsers)
|
||||||
|
|
||||||
|
var daily struct {
|
||||||
|
totalRequests int64
|
||||||
|
inputTokens int64
|
||||||
|
outputTokens int64
|
||||||
|
cacheCreationTokens int64
|
||||||
|
cacheReadTokens int64
|
||||||
|
totalCost float64
|
||||||
|
actualCost float64
|
||||||
|
totalDurationMs int64
|
||||||
|
activeUsers int64
|
||||||
|
}
|
||||||
|
err = scanSingleRow(s.ctx, s.tx, `
|
||||||
|
SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens,
|
||||||
|
total_cost, actual_cost, total_duration_ms, active_users
|
||||||
|
FROM usage_dashboard_daily
|
||||||
|
WHERE bucket_date = $1::date
|
||||||
|
`, []any{dayStart}, &daily.totalRequests, &daily.inputTokens, &daily.outputTokens,
|
||||||
|
&daily.cacheCreationTokens, &daily.cacheReadTokens, &daily.totalCost, &daily.actualCost,
|
||||||
|
&daily.totalDurationMs, &daily.activeUsers,
|
||||||
|
)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(int64(3), daily.totalRequests)
|
||||||
|
s.Require().Equal(int64(22), daily.inputTokens)
|
||||||
|
s.Require().Equal(int64(33), daily.outputTokens)
|
||||||
|
s.Require().Equal(int64(2), daily.cacheCreationTokens)
|
||||||
|
s.Require().Equal(int64(1), daily.cacheReadTokens)
|
||||||
|
s.Require().Equal(2.2, daily.totalCost)
|
||||||
|
s.Require().Equal(2.1, daily.actualCost)
|
||||||
|
s.Require().Equal(int64(450), daily.totalDurationMs)
|
||||||
|
s.Require().Equal(int64(2), daily.activeUsers)
|
||||||
|
}
|
||||||
|
|
||||||
// --- GetBatchUserUsageStats ---
|
// --- GetBatchUserUsageStats ---
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||||
@@ -398,7 +629,7 @@ func (s *UsageLogRepoSuite) TestGetGlobalStats() {
|
|||||||
s.Require().Equal(int64(45), stats.TotalOutputTokens)
|
s.Require().Equal(int64(45), stats.TotalOutputTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
func maxTime(a, b time.Time) time.Time {
|
func testMaxTime(a, b time.Time) time.Time {
|
||||||
if a.After(b) {
|
if a.After(b) {
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,7 +47,9 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewRedeemCodeRepository,
|
NewRedeemCodeRepository,
|
||||||
NewPromoCodeRepository,
|
NewPromoCodeRepository,
|
||||||
NewUsageLogRepository,
|
NewUsageLogRepository,
|
||||||
|
NewDashboardAggregationRepository,
|
||||||
NewSettingRepository,
|
NewSettingRepository,
|
||||||
|
NewOpsRepository,
|
||||||
NewUserSubscriptionRepository,
|
NewUserSubscriptionRepository,
|
||||||
NewUserAttributeDefinitionRepository,
|
NewUserAttributeDefinitionRepository,
|
||||||
NewUserAttributeValueRepository,
|
NewUserAttributeValueRepository,
|
||||||
@@ -57,12 +59,16 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewBillingCache,
|
NewBillingCache,
|
||||||
NewAPIKeyCache,
|
NewAPIKeyCache,
|
||||||
NewTempUnschedCache,
|
NewTempUnschedCache,
|
||||||
|
NewTimeoutCounterCache,
|
||||||
ProvideConcurrencyCache,
|
ProvideConcurrencyCache,
|
||||||
|
NewDashboardCache,
|
||||||
NewEmailCache,
|
NewEmailCache,
|
||||||
NewIdentityCache,
|
NewIdentityCache,
|
||||||
NewRedeemCache,
|
NewRedeemCache,
|
||||||
NewUpdateCache,
|
NewUpdateCache,
|
||||||
NewGeminiTokenCache,
|
NewGeminiTokenCache,
|
||||||
|
NewSchedulerCache,
|
||||||
|
NewSchedulerOutboxRepository,
|
||||||
|
|
||||||
// HTTP service ports (DI Strategy A: return interface directly)
|
// HTTP service ports (DI Strategy A: return interface directly)
|
||||||
NewTurnstileVerifier,
|
NewTurnstileVerifier,
|
||||||
|
|||||||
@@ -262,11 +262,11 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
name: "GET /api/v1/admin/settings",
|
name: "GET /api/v1/admin/settings",
|
||||||
setup: func(t *testing.T, deps *contractDeps) {
|
setup: func(t *testing.T, deps *contractDeps) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
deps.settingRepo.SetAll(map[string]string{
|
deps.settingRepo.SetAll(map[string]string{
|
||||||
service.SettingKeyRegistrationEnabled: "true",
|
service.SettingKeyRegistrationEnabled: "true",
|
||||||
service.SettingKeyEmailVerifyEnabled: "false",
|
service.SettingKeyEmailVerifyEnabled: "false",
|
||||||
|
|
||||||
service.SettingKeySMTPHost: "smtp.example.com",
|
service.SettingKeySMTPHost: "smtp.example.com",
|
||||||
service.SettingKeySMTPPort: "587",
|
service.SettingKeySMTPPort: "587",
|
||||||
service.SettingKeySMTPUsername: "user",
|
service.SettingKeySMTPUsername: "user",
|
||||||
service.SettingKeySMTPPassword: "secret",
|
service.SettingKeySMTPPassword: "secret",
|
||||||
@@ -285,10 +285,15 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
service.SettingKeyContactInfo: "support",
|
service.SettingKeyContactInfo: "support",
|
||||||
service.SettingKeyDocURL: "https://docs.example.com",
|
service.SettingKeyDocURL: "https://docs.example.com",
|
||||||
|
|
||||||
service.SettingKeyDefaultConcurrency: "5",
|
service.SettingKeyDefaultConcurrency: "5",
|
||||||
service.SettingKeyDefaultBalance: "1.25",
|
service.SettingKeyDefaultBalance: "1.25",
|
||||||
})
|
|
||||||
},
|
service.SettingKeyOpsMonitoringEnabled: "false",
|
||||||
|
service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
|
||||||
|
service.SettingKeyOpsQueryModeDefault: "auto",
|
||||||
|
service.SettingKeyOpsMetricsIntervalSeconds: "60",
|
||||||
|
})
|
||||||
|
},
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: "/api/v1/admin/settings",
|
path: "/api/v1/admin/settings",
|
||||||
wantStatus: http.StatusOK,
|
wantStatus: http.StatusOK,
|
||||||
@@ -309,13 +314,17 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"turnstile_site_key": "site-key",
|
"turnstile_site_key": "site-key",
|
||||||
"turnstile_secret_key_configured": true,
|
"turnstile_secret_key_configured": true,
|
||||||
"linuxdo_connect_enabled": false,
|
"linuxdo_connect_enabled": false,
|
||||||
"linuxdo_connect_client_id": "",
|
"linuxdo_connect_client_id": "",
|
||||||
"linuxdo_connect_client_secret_configured": false,
|
"linuxdo_connect_client_secret_configured": false,
|
||||||
"linuxdo_connect_redirect_url": "",
|
"linuxdo_connect_redirect_url": "",
|
||||||
"site_name": "Sub2API",
|
"ops_monitoring_enabled": false,
|
||||||
"site_logo": "",
|
"ops_realtime_monitoring_enabled": true,
|
||||||
"site_subtitle": "Subtitle",
|
"ops_query_mode_default": "auto",
|
||||||
"api_base_url": "https://api.example.com",
|
"ops_metrics_interval_seconds": 60,
|
||||||
|
"site_name": "Sub2API",
|
||||||
|
"site_logo": "",
|
||||||
|
"site_subtitle": "Subtitle",
|
||||||
|
"api_base_url": "https://api.example.com",
|
||||||
"contact_info": "support",
|
"contact_info": "support",
|
||||||
"doc_url": "https://docs.example.com",
|
"doc_url": "https://docs.example.com",
|
||||||
"default_concurrency": 5,
|
"default_concurrency": 5,
|
||||||
@@ -331,6 +340,30 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "POST /api/v1/admin/accounts/bulk-update",
|
||||||
|
method: http.MethodPost,
|
||||||
|
path: "/api/v1/admin/accounts/bulk-update",
|
||||||
|
body: `{"account_ids":[101,102],"schedulable":false}`,
|
||||||
|
headers: map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
wantJSON: `{
|
||||||
|
"code": 0,
|
||||||
|
"message": "success",
|
||||||
|
"data": {
|
||||||
|
"success": 2,
|
||||||
|
"failed": 0,
|
||||||
|
"success_ids": [101, 102],
|
||||||
|
"failed_ids": [],
|
||||||
|
"results": [
|
||||||
|
{"account_id": 101, "success": true},
|
||||||
|
{"account_id": 102, "success": true}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -382,6 +415,9 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
apiKeyCache := stubApiKeyCache{}
|
apiKeyCache := stubApiKeyCache{}
|
||||||
groupRepo := stubGroupRepo{}
|
groupRepo := stubGroupRepo{}
|
||||||
userSubRepo := stubUserSubscriptionRepo{}
|
userSubRepo := stubUserSubscriptionRepo{}
|
||||||
|
accountRepo := stubAccountRepo{}
|
||||||
|
proxyRepo := stubProxyRepo{}
|
||||||
|
redeemRepo := stubRedeemCodeRepo{}
|
||||||
|
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
Default: config.DefaultConfig{
|
Default: config.DefaultConfig{
|
||||||
@@ -390,19 +426,21 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
RunMode: config.RunModeStandard,
|
RunMode: config.RunModeStandard,
|
||||||
}
|
}
|
||||||
|
|
||||||
userService := service.NewUserService(userRepo)
|
userService := service.NewUserService(userRepo, nil)
|
||||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
|
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
|
||||||
|
|
||||||
usageRepo := newStubUsageLogRepo()
|
usageRepo := newStubUsageLogRepo()
|
||||||
usageService := service.NewUsageService(usageRepo, userRepo, nil)
|
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
|
||||||
|
|
||||||
settingRepo := newStubSettingRepo()
|
settingRepo := newStubSettingRepo()
|
||||||
settingService := service.NewSettingService(settingRepo, cfg)
|
settingService := service.NewSettingService(settingRepo, cfg)
|
||||||
|
|
||||||
|
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil)
|
||||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
|
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
||||||
|
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
jwtAuth := func(c *gin.Context) {
|
jwtAuth := func(c *gin.Context) {
|
||||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||||
@@ -442,6 +480,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
v1Admin := v1.Group("/admin")
|
v1Admin := v1.Group("/admin")
|
||||||
v1Admin.Use(adminAuth)
|
v1Admin.Use(adminAuth)
|
||||||
v1Admin.GET("/settings", adminSettingHandler.GetSettings)
|
v1Admin.GET("/settings", adminSettingHandler.GetSettings)
|
||||||
|
v1Admin.POST("/accounts/bulk-update", adminAccountHandler.BulkUpdate)
|
||||||
|
|
||||||
return &contractDeps{
|
return &contractDeps{
|
||||||
now: now,
|
now: now,
|
||||||
@@ -566,6 +605,18 @@ func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, t
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (stubApiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubApiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type stubGroupRepo struct{}
|
type stubGroupRepo struct{}
|
||||||
|
|
||||||
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
|
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
|
||||||
@@ -620,6 +671,235 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
|
|||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type stubAccountRepo struct {
|
||||||
|
bulkUpdateIDs []int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) Create(ctx context.Context, account *service.Account) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||||
|
return nil, service.ErrAccountNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||||
|
return false, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) Delete(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||||
|
return 0, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||||
|
s.bulkUpdateIDs = append([]int64{}, ids...)
|
||||||
|
return int64(len(ids)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubProxyRepo struct{}
|
||||||
|
|
||||||
|
func (stubProxyRepo) Create(ctx context.Context, proxy *service.Proxy) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
|
||||||
|
return nil, service.ErrProxyNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) Update(ctx context.Context, proxy *service.Proxy) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) Delete(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) ListActive(ctx context.Context) ([]service.Proxy, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||||
|
return false, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||||
|
return 0, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubRedeemCodeRepo struct{}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) CreateBatch(ctx context.Context, codes []service.RedeemCode) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
|
||||||
|
return nil, service.ErrRedeemCodeNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
|
||||||
|
return nil, service.ErrRedeemCodeNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) Delete(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubUserSubscriptionRepo struct{}
|
type stubUserSubscriptionRepo struct{}
|
||||||
|
|
||||||
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||||
@@ -738,12 +1018,12 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
|
|||||||
return &clone, nil
|
return &clone, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
key, ok := r.byID[id]
|
key, ok := r.byID[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, service.ErrAPIKeyNotFound
|
return "", 0, service.ErrAPIKeyNotFound
|
||||||
}
|
}
|
||||||
return key.UserID, nil
|
return key.Key, key.UserID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
@@ -755,6 +1035,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
|
|||||||
return &clone, nil
|
return &clone, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
return r.GetByKey(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||||
if key == nil {
|
if key == nil {
|
||||||
return errors.New("nil key")
|
return errors.New("nil key")
|
||||||
@@ -869,6 +1153,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
|
|||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubUsageLogRepo struct {
|
type stubUsageLogRepo struct {
|
||||||
userLogs map[int64][]service.UsageLog
|
userLogs map[int64][]service.UsageLog
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ func ProvideRouter(
|
|||||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||||
apiKeyService *service.APIKeyService,
|
apiKeyService *service.APIKeyService,
|
||||||
subscriptionService *service.SubscriptionService,
|
subscriptionService *service.SubscriptionService,
|
||||||
|
opsService *service.OpsService,
|
||||||
settingService *service.SettingService,
|
settingService *service.SettingService,
|
||||||
redisClient *redis.Client,
|
redisClient *redis.Client,
|
||||||
) *gin.Engine {
|
) *gin.Engine {
|
||||||
@@ -50,7 +51,7 @@ func ProvideRouter(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, settingService, cfg, redisClient)
|
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProvideHTTPServer 提供 HTTP 服务器
|
// ProvideHTTPServer 提供 HTTP 服务器
|
||||||
|
|||||||
@@ -30,6 +30,20 @@ func adminAuth(
|
|||||||
settingService *service.SettingService,
|
settingService *service.SettingService,
|
||||||
) gin.HandlerFunc {
|
) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
// WebSocket upgrade requests cannot set Authorization headers in browsers.
|
||||||
|
// For admin WebSocket endpoints (e.g. Ops realtime), allow passing the JWT via
|
||||||
|
// Sec-WebSocket-Protocol (subprotocol list) using a prefixed token item:
|
||||||
|
// Sec-WebSocket-Protocol: sub2api-admin, jwt.<token>
|
||||||
|
if isWebSocketUpgradeRequest(c) {
|
||||||
|
if token := extractJWTFromWebSocketSubprotocol(c); token != "" {
|
||||||
|
if !validateJWTForAdmin(c, token, authService, userService) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 检查 x-api-key header(Admin API Key 认证)
|
// 检查 x-api-key header(Admin API Key 认证)
|
||||||
apiKey := c.GetHeader("x-api-key")
|
apiKey := c.GetHeader("x-api-key")
|
||||||
if apiKey != "" {
|
if apiKey != "" {
|
||||||
@@ -58,6 +72,44 @@ func adminAuth(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isWebSocketUpgradeRequest(c *gin.Context) bool {
|
||||||
|
if c == nil || c.Request == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// RFC6455 handshake uses:
|
||||||
|
// Connection: Upgrade
|
||||||
|
// Upgrade: websocket
|
||||||
|
upgrade := strings.ToLower(strings.TrimSpace(c.GetHeader("Upgrade")))
|
||||||
|
if upgrade != "websocket" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
connection := strings.ToLower(c.GetHeader("Connection"))
|
||||||
|
return strings.Contains(connection, "upgrade")
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractJWTFromWebSocketSubprotocol(c *gin.Context) string {
|
||||||
|
if c == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
raw := strings.TrimSpace(c.GetHeader("Sec-WebSocket-Protocol"))
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// The header is a comma-separated list of tokens. We reserve the prefix "jwt."
|
||||||
|
// for carrying the admin JWT.
|
||||||
|
for _, part := range strings.Split(raw, ",") {
|
||||||
|
p := strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(p, "jwt.") {
|
||||||
|
token := strings.TrimSpace(strings.TrimPrefix(p, "jwt."))
|
||||||
|
if token != "" {
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// validateAdminAPIKey 验证管理员 API Key
|
// validateAdminAPIKey 验证管理员 API Key
|
||||||
func validateAdminAPIKey(
|
func validateAdminAPIKey(
|
||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
|||||||
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
func (f fakeAPIKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
return 0, errors.New("not implemented")
|
return "", 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
if f.getByKey == nil {
|
if f.getByKey == nil {
|
||||||
@@ -36,6 +36,9 @@ func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIK
|
|||||||
}
|
}
|
||||||
return f.getByKey(ctx, key)
|
return f.getByKey(ctx, key)
|
||||||
}
|
}
|
||||||
|
func (f fakeAPIKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
return f.GetByKey(ctx, key)
|
||||||
|
}
|
||||||
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
@@ -66,6 +69,12 @@ func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64
|
|||||||
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type googleErrorResponse struct {
|
type googleErrorResponse struct {
|
||||||
Error struct {
|
Error struct {
|
||||||
|
|||||||
@@ -256,8 +256,8 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
return 0, errors.New("not implemented")
|
return "", 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
@@ -267,6 +267,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
return r.GetByKey(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
@@ -307,6 +311,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
|
|||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubUserSubscriptionRepo struct {
|
type stubUserSubscriptionRepo struct {
|
||||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||||
|
|||||||
30
backend/internal/server/middleware/client_request_id.go
Normal file
30
backend/internal/server/middleware/client_request_id.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
|
||||||
|
//
|
||||||
|
// This is used by the Ops monitoring module for end-to-end request correlation.
|
||||||
|
func ClientRequestID() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if c.Request == nil {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := c.Request.Context().Value(ctxkey.ClientRequestID); v != nil {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := uuid.New().String()
|
||||||
|
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id))
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,6 +23,7 @@ func SetupRouter(
|
|||||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||||
apiKeyService *service.APIKeyService,
|
apiKeyService *service.APIKeyService,
|
||||||
subscriptionService *service.SubscriptionService,
|
subscriptionService *service.SubscriptionService,
|
||||||
|
opsService *service.OpsService,
|
||||||
settingService *service.SettingService,
|
settingService *service.SettingService,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
redisClient *redis.Client,
|
redisClient *redis.Client,
|
||||||
@@ -46,7 +47,7 @@ func SetupRouter(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 注册路由
|
// 注册路由
|
||||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg, redisClient)
|
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient)
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -60,6 +61,7 @@ func registerRoutes(
|
|||||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||||
apiKeyService *service.APIKeyService,
|
apiKeyService *service.APIKeyService,
|
||||||
subscriptionService *service.SubscriptionService,
|
subscriptionService *service.SubscriptionService,
|
||||||
|
opsService *service.OpsService,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
redisClient *redis.Client,
|
redisClient *redis.Client,
|
||||||
) {
|
) {
|
||||||
@@ -73,5 +75,5 @@ func registerRoutes(
|
|||||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
||||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,6 +50,9 @@ func RegisterAdminRoutes(
|
|||||||
// 系统设置
|
// 系统设置
|
||||||
registerSettingsRoutes(admin, h)
|
registerSettingsRoutes(admin, h)
|
||||||
|
|
||||||
|
// 运维监控(Ops)
|
||||||
|
registerOpsRoutes(admin, h)
|
||||||
|
|
||||||
// 系统管理
|
// 系统管理
|
||||||
registerSystemRoutes(admin, h)
|
registerSystemRoutes(admin, h)
|
||||||
|
|
||||||
@@ -64,6 +67,66 @@ func RegisterAdminRoutes(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
|
ops := admin.Group("/ops")
|
||||||
|
{
|
||||||
|
// Realtime ops signals
|
||||||
|
ops.GET("/concurrency", h.Admin.Ops.GetConcurrencyStats)
|
||||||
|
ops.GET("/account-availability", h.Admin.Ops.GetAccountAvailability)
|
||||||
|
ops.GET("/realtime-traffic", h.Admin.Ops.GetRealtimeTrafficSummary)
|
||||||
|
|
||||||
|
// Alerts (rules + events)
|
||||||
|
ops.GET("/alert-rules", h.Admin.Ops.ListAlertRules)
|
||||||
|
ops.POST("/alert-rules", h.Admin.Ops.CreateAlertRule)
|
||||||
|
ops.PUT("/alert-rules/:id", h.Admin.Ops.UpdateAlertRule)
|
||||||
|
ops.DELETE("/alert-rules/:id", h.Admin.Ops.DeleteAlertRule)
|
||||||
|
ops.GET("/alert-events", h.Admin.Ops.ListAlertEvents)
|
||||||
|
|
||||||
|
// Email notification config (DB-backed)
|
||||||
|
ops.GET("/email-notification/config", h.Admin.Ops.GetEmailNotificationConfig)
|
||||||
|
ops.PUT("/email-notification/config", h.Admin.Ops.UpdateEmailNotificationConfig)
|
||||||
|
|
||||||
|
// Runtime settings (DB-backed)
|
||||||
|
runtime := ops.Group("/runtime")
|
||||||
|
{
|
||||||
|
runtime.GET("/alert", h.Admin.Ops.GetAlertRuntimeSettings)
|
||||||
|
runtime.PUT("/alert", h.Admin.Ops.UpdateAlertRuntimeSettings)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Advanced settings (DB-backed)
|
||||||
|
ops.GET("/advanced-settings", h.Admin.Ops.GetAdvancedSettings)
|
||||||
|
ops.PUT("/advanced-settings", h.Admin.Ops.UpdateAdvancedSettings)
|
||||||
|
|
||||||
|
// Settings group (DB-backed)
|
||||||
|
settings := ops.Group("/settings")
|
||||||
|
{
|
||||||
|
settings.GET("/metric-thresholds", h.Admin.Ops.GetMetricThresholds)
|
||||||
|
settings.PUT("/metric-thresholds", h.Admin.Ops.UpdateMetricThresholds)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocket realtime (QPS/TPS)
|
||||||
|
ws := ops.Group("/ws")
|
||||||
|
{
|
||||||
|
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error logs (MVP-1)
|
||||||
|
ops.GET("/errors", h.Admin.Ops.GetErrorLogs)
|
||||||
|
ops.GET("/errors/:id", h.Admin.Ops.GetErrorLogByID)
|
||||||
|
ops.POST("/errors/:id/retry", h.Admin.Ops.RetryErrorRequest)
|
||||||
|
|
||||||
|
// Request drilldown (success + error)
|
||||||
|
ops.GET("/requests", h.Admin.Ops.ListRequestDetails)
|
||||||
|
|
||||||
|
// Dashboard (vNext - raw path for MVP)
|
||||||
|
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
|
||||||
|
ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
|
||||||
|
ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram)
|
||||||
|
ops.GET("/dashboard/error-trend", h.Admin.Ops.GetDashboardErrorTrend)
|
||||||
|
ops.GET("/dashboard/error-distribution", h.Admin.Ops.GetDashboardErrorDistribution)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
dashboard := admin.Group("/dashboard")
|
dashboard := admin.Group("/dashboard")
|
||||||
{
|
{
|
||||||
@@ -75,6 +138,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
||||||
|
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,6 +291,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
|
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
|
||||||
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
|
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
|
||||||
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
|
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
|
||||||
|
// 流超时处理配置
|
||||||
|
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
|
||||||
|
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -27,8 +27,10 @@ func RegisterAuthRoutes(
|
|||||||
auth.POST("/register", h.Auth.Register)
|
auth.POST("/register", h.Auth.Register)
|
||||||
auth.POST("/login", h.Auth.Login)
|
auth.POST("/login", h.Auth.Login)
|
||||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||||
// 优惠码验证接口添加速率限制:每分钟最多 10 次
|
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||||
auth.POST("/validate-promo-code", rateLimiter.Limit("validate-promo", 10, time.Minute), h.Auth.ValidatePromoCode)
|
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
|
||||||
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
}), h.Auth.ValidatePromoCode)
|
||||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,13 +16,18 @@ func RegisterGatewayRoutes(
|
|||||||
apiKeyAuth middleware.APIKeyAuthMiddleware,
|
apiKeyAuth middleware.APIKeyAuthMiddleware,
|
||||||
apiKeyService *service.APIKeyService,
|
apiKeyService *service.APIKeyService,
|
||||||
subscriptionService *service.SubscriptionService,
|
subscriptionService *service.SubscriptionService,
|
||||||
|
opsService *service.OpsService,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) {
|
) {
|
||||||
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
|
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
|
||||||
|
clientRequestID := middleware.ClientRequestID()
|
||||||
|
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
|
||||||
|
|
||||||
// API网关(Claude API兼容)
|
// API网关(Claude API兼容)
|
||||||
gateway := r.Group("/v1")
|
gateway := r.Group("/v1")
|
||||||
gateway.Use(bodyLimit)
|
gateway.Use(bodyLimit)
|
||||||
|
gateway.Use(clientRequestID)
|
||||||
|
gateway.Use(opsErrorLogger)
|
||||||
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
||||||
{
|
{
|
||||||
gateway.POST("/messages", h.Gateway.Messages)
|
gateway.POST("/messages", h.Gateway.Messages)
|
||||||
@@ -36,6 +41,8 @@ func RegisterGatewayRoutes(
|
|||||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||||
gemini := r.Group("/v1beta")
|
gemini := r.Group("/v1beta")
|
||||||
gemini.Use(bodyLimit)
|
gemini.Use(bodyLimit)
|
||||||
|
gemini.Use(clientRequestID)
|
||||||
|
gemini.Use(opsErrorLogger)
|
||||||
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||||
{
|
{
|
||||||
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||||
@@ -45,7 +52,7 @@ func RegisterGatewayRoutes(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OpenAI Responses API(不带v1前缀的别名)
|
// OpenAI Responses API(不带v1前缀的别名)
|
||||||
r.POST("/responses", bodyLimit, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
|
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
|
||||||
|
|
||||||
// Antigravity 模型列表
|
// Antigravity 模型列表
|
||||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
|
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
|
||||||
@@ -53,6 +60,8 @@ func RegisterGatewayRoutes(
|
|||||||
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
|
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
|
||||||
antigravityV1 := r.Group("/antigravity/v1")
|
antigravityV1 := r.Group("/antigravity/v1")
|
||||||
antigravityV1.Use(bodyLimit)
|
antigravityV1.Use(bodyLimit)
|
||||||
|
antigravityV1.Use(clientRequestID)
|
||||||
|
antigravityV1.Use(opsErrorLogger)
|
||||||
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||||
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
|
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||||
{
|
{
|
||||||
@@ -64,6 +73,8 @@ func RegisterGatewayRoutes(
|
|||||||
|
|
||||||
antigravityV1Beta := r.Group("/antigravity/v1beta")
|
antigravityV1Beta := r.Group("/antigravity/v1beta")
|
||||||
antigravityV1Beta.Use(bodyLimit)
|
antigravityV1Beta.Use(bodyLimit)
|
||||||
|
antigravityV1Beta.Use(clientRequestID)
|
||||||
|
antigravityV1Beta.Use(opsErrorLogger)
|
||||||
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||||
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -186,9 +186,11 @@ type BulkUpdateAccountResult struct {
|
|||||||
|
|
||||||
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
|
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
|
||||||
type BulkUpdateAccountsResult struct {
|
type BulkUpdateAccountsResult struct {
|
||||||
Success int `json:"success"`
|
Success int `json:"success"`
|
||||||
Failed int `json:"failed"`
|
Failed int `json:"failed"`
|
||||||
Results []BulkUpdateAccountResult `json:"results"`
|
SuccessIDs []int64 `json:"success_ids"`
|
||||||
|
FailedIDs []int64 `json:"failed_ids"`
|
||||||
|
Results []BulkUpdateAccountResult `json:"results"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateProxyInput struct {
|
type CreateProxyInput struct {
|
||||||
@@ -244,14 +246,15 @@ type ProxyExitInfoProber interface {
|
|||||||
|
|
||||||
// adminServiceImpl implements AdminService
|
// adminServiceImpl implements AdminService
|
||||||
type adminServiceImpl struct {
|
type adminServiceImpl struct {
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
groupRepo GroupRepository
|
groupRepo GroupRepository
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
proxyRepo ProxyRepository
|
proxyRepo ProxyRepository
|
||||||
apiKeyRepo APIKeyRepository
|
apiKeyRepo APIKeyRepository
|
||||||
redeemCodeRepo RedeemCodeRepository
|
redeemCodeRepo RedeemCodeRepository
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
proxyProber ProxyExitInfoProber
|
proxyProber ProxyExitInfoProber
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAdminService creates a new AdminService
|
// NewAdminService creates a new AdminService
|
||||||
@@ -264,16 +267,18 @@ func NewAdminService(
|
|||||||
redeemCodeRepo RedeemCodeRepository,
|
redeemCodeRepo RedeemCodeRepository,
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
proxyProber ProxyExitInfoProber,
|
proxyProber ProxyExitInfoProber,
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||||
) AdminService {
|
) AdminService {
|
||||||
return &adminServiceImpl{
|
return &adminServiceImpl{
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
proxyRepo: proxyRepo,
|
proxyRepo: proxyRepo,
|
||||||
apiKeyRepo: apiKeyRepo,
|
apiKeyRepo: apiKeyRepo,
|
||||||
redeemCodeRepo: redeemCodeRepo,
|
redeemCodeRepo: redeemCodeRepo,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
proxyProber: proxyProber,
|
proxyProber: proxyProber,
|
||||||
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,6 +328,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
}
|
}
|
||||||
|
|
||||||
oldConcurrency := user.Concurrency
|
oldConcurrency := user.Concurrency
|
||||||
|
oldStatus := user.Status
|
||||||
|
oldRole := user.Role
|
||||||
|
|
||||||
if input.Email != "" {
|
if input.Email != "" {
|
||||||
user.Email = input.Email
|
user.Email = input.Email
|
||||||
@@ -355,6 +362,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
concurrencyDiff := user.Concurrency - oldConcurrency
|
concurrencyDiff := user.Concurrency - oldConcurrency
|
||||||
if concurrencyDiff != 0 {
|
if concurrencyDiff != 0 {
|
||||||
@@ -393,6 +405,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
|
|||||||
log.Printf("delete user failed: user_id=%d err=%v", id, err)
|
log.Printf("delete user failed: user_id=%d err=%v", id, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -420,6 +435,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
|||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
balanceDiff := user.Balance - oldBalance
|
||||||
|
if s.authCacheInvalidator != nil && balanceDiff != 0 {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
if s.billingCacheService != nil {
|
if s.billingCacheService != nil {
|
||||||
go func() {
|
go func() {
|
||||||
@@ -431,7 +450,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
balanceDiff := user.Balance - oldBalance
|
|
||||||
if balanceDiff != 0 {
|
if balanceDiff != 0 {
|
||||||
code, err := GenerateRedeemCode()
|
code, err := GenerateRedeemCode()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -675,10 +693,21 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||||
|
}
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||||
|
var groupKeys []string
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id)
|
||||||
|
if err == nil {
|
||||||
|
groupKeys = keys
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
|
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -697,6 +726,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
for _, key := range groupKeys {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -885,7 +919,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
// It merges credentials/extra keys instead of overwriting the whole object.
|
// It merges credentials/extra keys instead of overwriting the whole object.
|
||||||
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
|
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
|
||||||
result := &BulkUpdateAccountsResult{
|
result := &BulkUpdateAccountsResult{
|
||||||
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
|
SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
|
||||||
|
FailedIDs: make([]int64, 0, len(input.AccountIDs)),
|
||||||
|
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(input.AccountIDs) == 0 {
|
if len(input.AccountIDs) == 0 {
|
||||||
@@ -949,6 +985,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
|||||||
entry.Success = false
|
entry.Success = false
|
||||||
entry.Error = err.Error()
|
entry.Error = err.Error()
|
||||||
result.Failed++
|
result.Failed++
|
||||||
|
result.FailedIDs = append(result.FailedIDs, accountID)
|
||||||
result.Results = append(result.Results, entry)
|
result.Results = append(result.Results, entry)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -958,6 +995,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
|||||||
entry.Success = false
|
entry.Success = false
|
||||||
entry.Error = err.Error()
|
entry.Error = err.Error()
|
||||||
result.Failed++
|
result.Failed++
|
||||||
|
result.FailedIDs = append(result.FailedIDs, accountID)
|
||||||
result.Results = append(result.Results, entry)
|
result.Results = append(result.Results, entry)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -967,6 +1005,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
|||||||
entry.Success = false
|
entry.Success = false
|
||||||
entry.Error = err.Error()
|
entry.Error = err.Error()
|
||||||
result.Failed++
|
result.Failed++
|
||||||
|
result.FailedIDs = append(result.FailedIDs, accountID)
|
||||||
result.Results = append(result.Results, entry)
|
result.Results = append(result.Results, entry)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -974,6 +1013,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
|||||||
|
|
||||||
entry.Success = true
|
entry.Success = true
|
||||||
result.Success++
|
result.Success++
|
||||||
|
result.SuccessIDs = append(result.SuccessIDs, accountID)
|
||||||
result.Results = append(result.Results, entry)
|
result.Results = append(result.Results, entry)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
80
backend/internal/service/admin_service_bulk_update_test.go
Normal file
80
backend/internal/service/admin_service_bulk_update_test.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type accountRepoStubForBulkUpdate struct {
|
||||||
|
accountRepoStub
|
||||||
|
bulkUpdateErr error
|
||||||
|
bulkUpdateIDs []int64
|
||||||
|
bindGroupErrByID map[int64]error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
|
||||||
|
s.bulkUpdateIDs = append([]int64{}, ids...)
|
||||||
|
if s.bulkUpdateErr != nil {
|
||||||
|
return 0, s.bulkUpdateErr
|
||||||
|
}
|
||||||
|
return int64(len(ids)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
|
||||||
|
if err, ok := s.bindGroupErrByID[accountID]; ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||||
|
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||||
|
repo := &accountRepoStubForBulkUpdate{}
|
||||||
|
svc := &adminServiceImpl{accountRepo: repo}
|
||||||
|
|
||||||
|
schedulable := true
|
||||||
|
input := &BulkUpdateAccountsInput{
|
||||||
|
AccountIDs: []int64{1, 2, 3},
|
||||||
|
Schedulable: &schedulable,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 3, result.Success)
|
||||||
|
require.Equal(t, 0, result.Failed)
|
||||||
|
require.ElementsMatch(t, []int64{1, 2, 3}, result.SuccessIDs)
|
||||||
|
require.Empty(t, result.FailedIDs)
|
||||||
|
require.Len(t, result.Results, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAdminService_BulkUpdateAccounts_PartialFailureIDs 验证部分失败时 success_ids/failed_ids 正确。
|
||||||
|
func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
|
||||||
|
repo := &accountRepoStubForBulkUpdate{
|
||||||
|
bindGroupErrByID: map[int64]error{
|
||||||
|
2: errors.New("bind failed"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{accountRepo: repo}
|
||||||
|
|
||||||
|
groupIDs := []int64{10}
|
||||||
|
schedulable := false
|
||||||
|
input := &BulkUpdateAccountsInput{
|
||||||
|
AccountIDs: []int64{1, 2, 3},
|
||||||
|
GroupIDs: &groupIDs,
|
||||||
|
Schedulable: &schedulable,
|
||||||
|
SkipMixedChannelCheck: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 2, result.Success)
|
||||||
|
require.Equal(t, 1, result.Failed)
|
||||||
|
require.ElementsMatch(t, []int64{1, 3}, result.SuccessIDs)
|
||||||
|
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
|
||||||
|
require.Len(t, result.Results, 3)
|
||||||
|
}
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type balanceUserRepoStub struct {
|
||||||
|
*userRepoStub
|
||||||
|
updateErr error
|
||||||
|
updated []*User
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error {
|
||||||
|
if s.updateErr != nil {
|
||||||
|
return s.updateErr
|
||||||
|
}
|
||||||
|
if user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
clone := *user
|
||||||
|
s.updated = append(s.updated, &clone)
|
||||||
|
if s.userRepoStub != nil {
|
||||||
|
s.userRepoStub.user = &clone
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type balanceRedeemRepoStub struct {
|
||||||
|
*redeemRepoStub
|
||||||
|
created []*RedeemCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
|
||||||
|
if code == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
clone := *code
|
||||||
|
s.created = append(s.created, &clone)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type authCacheInvalidatorStub struct {
|
||||||
|
userIDs []int64
|
||||||
|
groupIDs []int64
|
||||||
|
keys []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) {
|
||||||
|
s.keys = append(s.keys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
|
||||||
|
s.userIDs = append(s.userIDs, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
|
||||||
|
s.groupIDs = append(s.groupIDs, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) {
|
||||||
|
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
|
||||||
|
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
|
||||||
|
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
userRepo: repo,
|
||||||
|
redeemCodeRepo: redeemRepo,
|
||||||
|
authCacheInvalidator: invalidator,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []int64{7}, invalidator.userIDs)
|
||||||
|
require.Len(t, redeemRepo.created, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) {
|
||||||
|
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
|
||||||
|
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
|
||||||
|
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
userRepo: repo,
|
||||||
|
redeemCodeRepo: redeemRepo,
|
||||||
|
authCacheInvalidator: invalidator,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, invalidator.userIDs)
|
||||||
|
require.Empty(t, redeemRepo.created)
|
||||||
|
}
|
||||||
@@ -523,6 +523,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sanitize thinking blocks (clean cache_control and flatten history thinking)
|
||||||
|
sanitizeThinkingBlocks(&claudeReq)
|
||||||
|
|
||||||
// 获取转换选项
|
// 获取转换选项
|
||||||
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
|
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
|
||||||
transformOpts := s.getClaudeTransformOptions(ctx)
|
transformOpts := s.getClaudeTransformOptions(ctx)
|
||||||
@@ -534,6 +537,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
return nil, fmt.Errorf("transform request: %w", err)
|
return nil, fmt.Errorf("transform request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Safety net: ensure no cache_control leaked into Gemini request
|
||||||
|
geminiBody = cleanCacheControlFromGeminiJSON(geminiBody)
|
||||||
|
|
||||||
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
|
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
|
||||||
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
|
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
|
||||||
action := "streamGenerateContent"
|
action := "streamGenerateContent"
|
||||||
@@ -564,6 +570,14 @@ urlFallbackLoop:
|
|||||||
|
|
||||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: safeErr,
|
||||||
|
})
|
||||||
// 检查是否应触发 URL 降级
|
// 检查是否应触发 URL 降级
|
||||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||||
@@ -579,6 +593,7 @@ urlFallbackLoop:
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||||
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -586,6 +601,26 @@ urlFallbackLoop:
|
|||||||
if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
|
if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||||
|
maxBytes := 2048
|
||||||
|
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
upstreamDetail := ""
|
||||||
|
if logBody {
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "retry",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||||
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
|
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
|
||||||
continue urlFallbackLoop
|
continue urlFallbackLoop
|
||||||
@@ -596,6 +631,26 @@ urlFallbackLoop:
|
|||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
if attempt < antigravityMaxRetries {
|
if attempt < antigravityMaxRetries {
|
||||||
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||||
|
maxBytes := 2048
|
||||||
|
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
upstreamDetail := ""
|
||||||
|
if logBody {
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "retry",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
|
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
|
||||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||||
@@ -628,6 +683,27 @@ urlFallbackLoop:
|
|||||||
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
|
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
|
||||||
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
|
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
|
||||||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
|
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
|
||||||
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||||
|
maxBytes := 2048
|
||||||
|
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
upstreamDetail := ""
|
||||||
|
if logBody {
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "signature_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
// Conservative two-stage fallback:
|
// Conservative two-stage fallback:
|
||||||
// 1) Disable top-level thinking + thinking->text
|
// 1) Disable top-level thinking + thinking->text
|
||||||
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
|
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
|
||||||
@@ -661,6 +737,13 @@ urlFallbackLoop:
|
|||||||
}
|
}
|
||||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||||
if retryErr != nil {
|
if retryErr != nil {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "signature_retry_request_error",
|
||||||
|
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||||||
|
})
|
||||||
log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
|
log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -674,6 +757,25 @@ urlFallbackLoop:
|
|||||||
|
|
||||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||||
_ = retryResp.Body.Close()
|
_ = retryResp.Body.Close()
|
||||||
|
kind := "signature_retry"
|
||||||
|
if strings.TrimSpace(stage.name) != "" {
|
||||||
|
kind = "signature_retry_" + strings.ReplaceAll(stage.name, "+", "_")
|
||||||
|
}
|
||||||
|
retryUpstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(retryBody))
|
||||||
|
retryUpstreamMsg = sanitizeUpstreamErrorMessage(retryUpstreamMsg)
|
||||||
|
retryUpstreamDetail := ""
|
||||||
|
if logBody {
|
||||||
|
retryUpstreamDetail = truncateString(string(retryBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: retryResp.StatusCode,
|
||||||
|
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||||||
|
Kind: kind,
|
||||||
|
Message: retryUpstreamMsg,
|
||||||
|
Detail: retryUpstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
|
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
|
||||||
if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) {
|
if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) {
|
||||||
@@ -701,10 +803,30 @@ urlFallbackLoop:
|
|||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||||
|
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||||
|
maxBytes := 2048
|
||||||
|
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
upstreamDetail := ""
|
||||||
|
if logBody {
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
|
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -787,6 +909,143 @@ func extractAntigravityErrorMessage(body []byte) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cleanCacheControlFromGeminiJSON removes cache_control from Gemini JSON (emergency fix)
|
||||||
|
// This should not be needed if transformation is correct, but serves as a safety net
|
||||||
|
func cleanCacheControlFromGeminiJSON(body []byte) []byte {
|
||||||
|
// Try a more robust approach: parse and clean
|
||||||
|
var data map[string]any
|
||||||
|
if err := json.Unmarshal(body, &data); err != nil {
|
||||||
|
log.Printf("[Antigravity] Failed to parse Gemini JSON for cache_control cleaning: %v", err)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned := removeCacheControlFromAny(data)
|
||||||
|
if !cleaned {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
if result, err := json.Marshal(data); err == nil {
|
||||||
|
log.Printf("[Antigravity] Successfully cleaned cache_control from Gemini JSON")
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeCacheControlFromAny recursively removes cache_control fields
|
||||||
|
func removeCacheControlFromAny(v any) bool {
|
||||||
|
cleaned := false
|
||||||
|
|
||||||
|
switch val := v.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
for k, child := range val {
|
||||||
|
if k == "cache_control" {
|
||||||
|
delete(val, k)
|
||||||
|
cleaned = true
|
||||||
|
} else if removeCacheControlFromAny(child) {
|
||||||
|
cleaned = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, item := range val {
|
||||||
|
if removeCacheControlFromAny(item) {
|
||||||
|
cleaned = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeThinkingBlocks cleans cache_control and flattens history thinking blocks
|
||||||
|
// Thinking blocks do NOT support cache_control field (Anthropic API/Vertex AI requirement)
|
||||||
|
// Additionally, history thinking blocks are flattened to text to avoid upstream validation errors
|
||||||
|
func sanitizeThinkingBlocks(req *antigravity.ClaudeRequest) {
|
||||||
|
if req == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[Antigravity] sanitizeThinkingBlocks: processing request with %d messages", len(req.Messages))
|
||||||
|
|
||||||
|
// Clean system blocks
|
||||||
|
if len(req.System) > 0 {
|
||||||
|
var systemBlocks []map[string]any
|
||||||
|
if err := json.Unmarshal(req.System, &systemBlocks); err == nil {
|
||||||
|
for i := range systemBlocks {
|
||||||
|
if blockType, _ := systemBlocks[i]["type"].(string); blockType == "thinking" || systemBlocks[i]["thinking"] != nil {
|
||||||
|
if removeCacheControlFromAny(systemBlocks[i]) {
|
||||||
|
log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in system[%d]", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Marshal back
|
||||||
|
if cleaned, err := json.Marshal(systemBlocks); err == nil {
|
||||||
|
req.System = cleaned
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean message content blocks and flatten history
|
||||||
|
lastMsgIdx := len(req.Messages) - 1
|
||||||
|
for msgIdx := range req.Messages {
|
||||||
|
raw := req.Messages[msgIdx].Content
|
||||||
|
if len(raw) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse as blocks array
|
||||||
|
var blocks []map[string]any
|
||||||
|
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned := false
|
||||||
|
for blockIdx := range blocks {
|
||||||
|
blockType, _ := blocks[blockIdx]["type"].(string)
|
||||||
|
|
||||||
|
// Check for thinking blocks (typed or untyped)
|
||||||
|
if blockType == "thinking" || blocks[blockIdx]["thinking"] != nil {
|
||||||
|
// 1. Clean cache_control
|
||||||
|
if removeCacheControlFromAny(blocks[blockIdx]) {
|
||||||
|
log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in messages[%d].content[%d]", msgIdx, blockIdx)
|
||||||
|
cleaned = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Flatten to text if it's a history message (not the last one)
|
||||||
|
if msgIdx < lastMsgIdx {
|
||||||
|
log.Printf("[Antigravity] Flattening history thinking block to text at messages[%d].content[%d]", msgIdx, blockIdx)
|
||||||
|
|
||||||
|
// Extract thinking content
|
||||||
|
var textContent string
|
||||||
|
if t, ok := blocks[blockIdx]["thinking"].(string); ok {
|
||||||
|
textContent = t
|
||||||
|
} else {
|
||||||
|
// Fallback for non-string content (marshal it)
|
||||||
|
if b, err := json.Marshal(blocks[blockIdx]["thinking"]); err == nil {
|
||||||
|
textContent = string(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to text block
|
||||||
|
blocks[blockIdx]["type"] = "text"
|
||||||
|
blocks[blockIdx]["text"] = textContent
|
||||||
|
delete(blocks[blockIdx], "thinking")
|
||||||
|
delete(blocks[blockIdx], "signature")
|
||||||
|
delete(blocks[blockIdx], "cache_control") // Ensure it's gone
|
||||||
|
cleaned = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal back if modified
|
||||||
|
if cleaned {
|
||||||
|
if marshaled, err := json.Marshal(blocks); err == nil {
|
||||||
|
req.Messages[msgIdx].Content = marshaled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
|
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
|
||||||
// This preserves the thinking content while avoiding signature validation errors.
|
// This preserves the thinking content while avoiding signature validation errors.
|
||||||
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
|
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
|
||||||
@@ -1108,6 +1367,14 @@ urlFallbackLoop:
|
|||||||
|
|
||||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: safeErr,
|
||||||
|
})
|
||||||
// 检查是否应触发 URL 降级
|
// 检查是否应触发 URL 降级
|
||||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||||
@@ -1123,6 +1390,7 @@ urlFallbackLoop:
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||||
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1130,6 +1398,26 @@ urlFallbackLoop:
|
|||||||
if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
|
if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||||
|
maxBytes := 2048
|
||||||
|
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
upstreamDetail := ""
|
||||||
|
if logBody {
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "retry",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||||
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
|
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
|
||||||
continue urlFallbackLoop
|
continue urlFallbackLoop
|
||||||
@@ -1140,6 +1428,26 @@ urlFallbackLoop:
|
|||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
if attempt < antigravityMaxRetries {
|
if attempt < antigravityMaxRetries {
|
||||||
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||||
|
maxBytes := 2048
|
||||||
|
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
upstreamDetail := ""
|
||||||
|
if logBody {
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "retry",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
|
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||||
@@ -1205,21 +1513,59 @@ urlFallbackLoop:
|
|||||||
|
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||||
|
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解包并返回错误
|
|
||||||
requestID := resp.Header.Get("x-request-id")
|
requestID := resp.Header.Get("x-request-id")
|
||||||
if requestID != "" {
|
if requestID != "" {
|
||||||
c.Header("x-request-id", requestID)
|
c.Header("x-request-id", requestID)
|
||||||
}
|
}
|
||||||
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
|
|
||||||
|
unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody)
|
||||||
|
unwrappedForOps := unwrapped
|
||||||
|
if unwrapErr != nil || len(unwrappedForOps) == 0 {
|
||||||
|
unwrappedForOps = respBody
|
||||||
|
}
|
||||||
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
|
||||||
|
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||||
|
maxBytes := 2048
|
||||||
|
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
upstreamDetail := ""
|
||||||
|
if logBody {
|
||||||
|
upstreamDetail = truncateString(string(unwrappedForOps), maxBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always record upstream context for Ops error logs, even when we will failover.
|
||||||
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||||
|
|
||||||
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: requestID,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
|
}
|
||||||
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
if contentType == "" {
|
if contentType == "" {
|
||||||
contentType = "application/json"
|
contentType = "application/json"
|
||||||
}
|
}
|
||||||
c.Data(resp.StatusCode, contentType, unwrapped)
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: requestID,
|
||||||
|
Kind: "http_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
c.Data(resp.StatusCode, contentType, unwrappedForOps)
|
||||||
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1514,6 +1860,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Printf("Stream data interval timeout (antigravity)")
|
log.Printf("Stream data interval timeout (antigravity)")
|
||||||
|
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
}
|
}
|
||||||
@@ -1674,9 +2021,35 @@ func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int,
|
|||||||
return fmt.Errorf("%s", message)
|
return fmt.Errorf("%s", message)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstreamStatus int, body []byte) error {
|
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
|
||||||
// 记录上游错误详情便于调试
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||||
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, string(body))
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
|
||||||
|
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||||
|
maxBytes := 2048
|
||||||
|
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||||
|
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamDetail := ""
|
||||||
|
if logBody {
|
||||||
|
upstreamDetail = truncateString(string(body), maxBytes)
|
||||||
|
}
|
||||||
|
setOpsUpstreamError(c, upstreamStatus, upstreamMsg, upstreamDetail)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: upstreamStatus,
|
||||||
|
UpstreamRequestID: upstreamRequestID,
|
||||||
|
Kind: "http_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 记录上游错误详情便于排障(可选:由配置控制;不回显到客户端)
|
||||||
|
if logBody {
|
||||||
|
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
|
||||||
|
}
|
||||||
|
|
||||||
var statusCode int
|
var statusCode int
|
||||||
var errType, errMsg string
|
var errType, errMsg string
|
||||||
@@ -1712,7 +2085,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstr
|
|||||||
"type": "error",
|
"type": "error",
|
||||||
"error": gin.H{"type": errType, "message": errMsg},
|
"error": gin.H{"type": errType, "message": errMsg},
|
||||||
})
|
})
|
||||||
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
if upstreamMsg == "" {
|
||||||
|
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
||||||
@@ -2039,6 +2415,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Printf("Stream data interval timeout (antigravity)")
|
log.Printf("Stream data interval timeout (antigravity)")
|
||||||
|
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
}
|
}
|
||||||
|
|||||||
46
backend/internal/service/api_key_auth_cache.go
Normal file
46
backend/internal/service/api_key_auth_cache.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
|
||||||
|
type APIKeyAuthSnapshot struct {
|
||||||
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
GroupID *int64 `json:"group_id,omitempty"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
IPWhitelist []string `json:"ip_whitelist,omitempty"`
|
||||||
|
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
||||||
|
User APIKeyAuthUserSnapshot `json:"user"`
|
||||||
|
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthUserSnapshot 用户快照
|
||||||
|
type APIKeyAuthUserSnapshot struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
Balance float64 `json:"balance"`
|
||||||
|
Concurrency int `json:"concurrency"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthGroupSnapshot 分组快照
|
||||||
|
type APIKeyAuthGroupSnapshot struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Platform string `json:"platform"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
SubscriptionType string `json:"subscription_type"`
|
||||||
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
|
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||||
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||||
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||||
|
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||||
|
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||||
|
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||||
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||||
|
type APIKeyAuthCacheEntry struct {
|
||||||
|
NotFound bool `json:"not_found"`
|
||||||
|
Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"`
|
||||||
|
}
|
||||||
269
backend/internal/service/api_key_auth_cache_impl.go
Normal file
269
backend/internal/service/api_key_auth_cache_impl.go
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/dgraph-io/ristretto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type apiKeyAuthCacheConfig struct {
|
||||||
|
l1Size int
|
||||||
|
l1TTL time.Duration
|
||||||
|
l2TTL time.Duration
|
||||||
|
negativeTTL time.Duration
|
||||||
|
jitterPercent int
|
||||||
|
singleflight bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
jitterRandMu sync.Mutex
|
||||||
|
// 认证缓存抖动使用独立随机源,避免全局 Seed
|
||||||
|
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
|
||||||
|
if cfg == nil {
|
||||||
|
return apiKeyAuthCacheConfig{}
|
||||||
|
}
|
||||||
|
auth := cfg.APIKeyAuth
|
||||||
|
return apiKeyAuthCacheConfig{
|
||||||
|
l1Size: auth.L1Size,
|
||||||
|
l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second,
|
||||||
|
l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second,
|
||||||
|
negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second,
|
||||||
|
jitterPercent: auth.JitterPercent,
|
||||||
|
singleflight: auth.Singleflight,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c apiKeyAuthCacheConfig) l1Enabled() bool {
|
||||||
|
return c.l1Size > 0 && c.l1TTL > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c apiKeyAuthCacheConfig) l2Enabled() bool {
|
||||||
|
return c.l2TTL > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
|
||||||
|
return c.negativeTTL > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
|
||||||
|
if ttl <= 0 {
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
if c.jitterPercent <= 0 {
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
percent := c.jitterPercent
|
||||||
|
if percent > 100 {
|
||||||
|
percent = 100
|
||||||
|
}
|
||||||
|
delta := float64(percent) / 100
|
||||||
|
jitterRandMu.Lock()
|
||||||
|
randVal := jitterRand.Float64()
|
||||||
|
jitterRandMu.Unlock()
|
||||||
|
factor := 1 - delta + randVal*(2*delta)
|
||||||
|
if factor <= 0 {
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
return time.Duration(float64(ttl) * factor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) initAuthCache(cfg *config.Config) {
|
||||||
|
s.authCfg = newAPIKeyAuthCacheConfig(cfg)
|
||||||
|
if !s.authCfg.l1Enabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cache, err := ristretto.NewCache(&ristretto.Config{
|
||||||
|
NumCounters: int64(s.authCfg.l1Size) * 10,
|
||||||
|
MaxCost: int64(s.authCfg.l1Size),
|
||||||
|
BufferItems: 64,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.authCacheL1 = cache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) authCacheKey(key string) string {
|
||||||
|
sum := sha256.Sum256([]byte(key))
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) {
|
||||||
|
if s.authCacheL1 != nil {
|
||||||
|
if val, ok := s.authCacheL1.Get(cacheKey); ok {
|
||||||
|
if entry, ok := val.(*APIKeyAuthCacheEntry); ok {
|
||||||
|
return entry, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.cache == nil || !s.authCfg.l2Enabled() {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
entry, err := s.cache.GetAuthCache(ctx, cacheKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
s.setAuthCacheL1(cacheKey, entry)
|
||||||
|
return entry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) {
|
||||||
|
if s.authCacheL1 == nil || entry == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ttl := s.authCfg.l1TTL
|
||||||
|
if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl {
|
||||||
|
ttl = s.authCfg.negativeTTL
|
||||||
|
}
|
||||||
|
ttl = s.authCfg.jitterTTL(ttl)
|
||||||
|
_ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) {
|
||||||
|
if entry == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.setAuthCacheL1(cacheKey, entry)
|
||||||
|
if s.cache == nil || !s.authCfg.l2Enabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
|
||||||
|
if s.authCacheL1 != nil {
|
||||||
|
s.authCacheL1.Del(cacheKey)
|
||||||
|
}
|
||||||
|
if s.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrAPIKeyNotFound) {
|
||||||
|
entry := &APIKeyAuthCacheEntry{NotFound: true}
|
||||||
|
if s.authCfg.negativeEnabled() {
|
||||||
|
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL)
|
||||||
|
}
|
||||||
|
return entry, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
|
}
|
||||||
|
apiKey.Key = key
|
||||||
|
snapshot := s.snapshotFromAPIKey(apiKey)
|
||||||
|
if snapshot == nil {
|
||||||
|
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
|
||||||
|
}
|
||||||
|
entry := &APIKeyAuthCacheEntry{Snapshot: snapshot}
|
||||||
|
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL)
|
||||||
|
return entry, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) {
|
||||||
|
if entry == nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
if entry.NotFound {
|
||||||
|
return nil, true, ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
if entry.Snapshot == nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||||
|
if apiKey == nil || apiKey.User == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
snapshot := &APIKeyAuthSnapshot{
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: apiKey.UserID,
|
||||||
|
GroupID: apiKey.GroupID,
|
||||||
|
Status: apiKey.Status,
|
||||||
|
IPWhitelist: apiKey.IPWhitelist,
|
||||||
|
IPBlacklist: apiKey.IPBlacklist,
|
||||||
|
User: APIKeyAuthUserSnapshot{
|
||||||
|
ID: apiKey.User.ID,
|
||||||
|
Status: apiKey.User.Status,
|
||||||
|
Role: apiKey.User.Role,
|
||||||
|
Balance: apiKey.User.Balance,
|
||||||
|
Concurrency: apiKey.User.Concurrency,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if apiKey.Group != nil {
|
||||||
|
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||||
|
ID: apiKey.Group.ID,
|
||||||
|
Name: apiKey.Group.Name,
|
||||||
|
Platform: apiKey.Group.Platform,
|
||||||
|
Status: apiKey.Group.Status,
|
||||||
|
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||||
|
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||||
|
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||||
|
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||||
|
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||||
|
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||||
|
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||||
|
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||||
|
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||||
|
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey {
|
||||||
|
if snapshot == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
apiKey := &APIKey{
|
||||||
|
ID: snapshot.APIKeyID,
|
||||||
|
UserID: snapshot.UserID,
|
||||||
|
GroupID: snapshot.GroupID,
|
||||||
|
Key: key,
|
||||||
|
Status: snapshot.Status,
|
||||||
|
IPWhitelist: snapshot.IPWhitelist,
|
||||||
|
IPBlacklist: snapshot.IPBlacklist,
|
||||||
|
User: &User{
|
||||||
|
ID: snapshot.User.ID,
|
||||||
|
Status: snapshot.User.Status,
|
||||||
|
Role: snapshot.User.Role,
|
||||||
|
Balance: snapshot.User.Balance,
|
||||||
|
Concurrency: snapshot.User.Concurrency,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if snapshot.Group != nil {
|
||||||
|
apiKey.Group = &Group{
|
||||||
|
ID: snapshot.Group.ID,
|
||||||
|
Name: snapshot.Group.Name,
|
||||||
|
Platform: snapshot.Group.Platform,
|
||||||
|
Status: snapshot.Group.Status,
|
||||||
|
Hydrated: true,
|
||||||
|
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||||
|
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||||
|
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||||
|
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||||
|
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||||
|
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||||
|
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||||
|
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||||
|
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||||
|
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return apiKey
|
||||||
|
}
|
||||||
48
backend/internal/service/api_key_auth_cache_invalidate.go
Normal file
48
backend/internal/service/api_key_auth_cache_invalidate.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存
|
||||||
|
func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) {
|
||||||
|
if key == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cacheKey := s.authCacheKey(key)
|
||||||
|
s.deleteAuthCache(ctx, cacheKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存
|
||||||
|
func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
|
||||||
|
if userID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.deleteAuthCacheByKeys(ctx, keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存
|
||||||
|
func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
|
||||||
|
if groupID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.deleteAuthCacheByKeys(ctx, keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.deleteAuthCache(ctx, s.authCacheKey(key))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
"github.com/dgraph-io/ristretto"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -31,9 +33,11 @@ const (
|
|||||||
type APIKeyRepository interface {
|
type APIKeyRepository interface {
|
||||||
Create(ctx context.Context, key *APIKey) error
|
Create(ctx context.Context, key *APIKey) error
|
||||||
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||||
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证
|
// GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID,用于删除等轻量场景
|
||||||
GetOwnerID(ctx context.Context, id int64) (int64, error)
|
GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error)
|
||||||
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
||||||
|
// GetByKeyForAuth 认证专用查询,返回最小字段集
|
||||||
|
GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error)
|
||||||
Update(ctx context.Context, key *APIKey) error
|
Update(ctx context.Context, key *APIKey) error
|
||||||
Delete(ctx context.Context, id int64) error
|
Delete(ctx context.Context, id int64) error
|
||||||
|
|
||||||
@@ -45,6 +49,8 @@ type APIKeyRepository interface {
|
|||||||
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
||||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||||
|
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
|
||||||
|
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyCache defines cache operations for API key service
|
// APIKeyCache defines cache operations for API key service
|
||||||
@@ -55,6 +61,17 @@ type APIKeyCache interface {
|
|||||||
|
|
||||||
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
||||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||||
|
|
||||||
|
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||||
|
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
|
||||||
|
DeleteAuthCache(ctx context.Context, key string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
|
||||||
|
type APIKeyAuthCacheInvalidator interface {
|
||||||
|
InvalidateAuthCacheByKey(ctx context.Context, key string)
|
||||||
|
InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
|
||||||
|
InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAPIKeyRequest 创建API Key请求
|
// CreateAPIKeyRequest 创建API Key请求
|
||||||
@@ -83,6 +100,9 @@ type APIKeyService struct {
|
|||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
cache APIKeyCache
|
cache APIKeyCache
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
authCacheL1 *ristretto.Cache
|
||||||
|
authCfg apiKeyAuthCacheConfig
|
||||||
|
authGroup singleflight.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAPIKeyService 创建API Key服务实例
|
// NewAPIKeyService 创建API Key服务实例
|
||||||
@@ -94,7 +114,7 @@ func NewAPIKeyService(
|
|||||||
cache APIKeyCache,
|
cache APIKeyCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) *APIKeyService {
|
) *APIKeyService {
|
||||||
return &APIKeyService{
|
svc := &APIKeyService{
|
||||||
apiKeyRepo: apiKeyRepo,
|
apiKeyRepo: apiKeyRepo,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
@@ -102,6 +122,8 @@ func NewAPIKeyService(
|
|||||||
cache: cache,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
|
svc.initAuthCache(cfg)
|
||||||
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateKey 生成随机API Key
|
// GenerateKey 生成随机API Key
|
||||||
@@ -269,6 +291,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
|||||||
return nil, fmt.Errorf("create api key: %w", err)
|
return nil, fmt.Errorf("create api key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||||
|
|
||||||
return apiKey, nil
|
return apiKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -304,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
|
|||||||
|
|
||||||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||||||
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||||
// 尝试从Redis缓存获取
|
cacheKey := s.authCacheKey(key)
|
||||||
cacheKey := fmt.Sprintf("apikey:%s", key)
|
|
||||||
|
|
||||||
// 这里可以添加Redis缓存逻辑,暂时直接查询数据库
|
if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok {
|
||||||
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
|
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
|
}
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.authCfg.singleflight {
|
||||||
|
value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) {
|
||||||
|
return s.loadAuthCacheEntry(ctx, key, cacheKey)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
entry, _ := value.(*APIKeyAuthCacheEntry)
|
||||||
|
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
|
}
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
|
}
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get api key: %w", err)
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
}
|
}
|
||||||
|
apiKey.Key = key
|
||||||
// 缓存到Redis(可选,TTL设置为5分钟)
|
|
||||||
if s.cache != nil {
|
|
||||||
// 这里可以序列化并缓存API Key
|
|
||||||
_ = cacheKey // 使用变量避免未使用错误
|
|
||||||
}
|
|
||||||
|
|
||||||
return apiKey, nil
|
return apiKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,15 +440,14 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
|||||||
return nil, fmt.Errorf("update api key: %w", err)
|
return nil, fmt.Errorf("update api key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||||
|
|
||||||
return apiKey, nil
|
return apiKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete 删除API Key
|
// Delete 删除API Key
|
||||||
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
|
|
||||||
// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能
|
|
||||||
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||||
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
|
key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id)
|
||||||
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get api key: %w", err)
|
return fmt.Errorf("get api key: %w", err)
|
||||||
}
|
}
|
||||||
@@ -406,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
|||||||
return ErrInsufficientPerms
|
return ErrInsufficientPerms
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清除Redis缓存(使用 ownerID 而非 apiKey.UserID)
|
// 清除Redis缓存(使用 userID 而非 apiKey.UserID)
|
||||||
if s.cache != nil {
|
if s.cache != nil {
|
||||||
_ = s.cache.DeleteCreateAttemptCount(ctx, ownerID)
|
_ = s.cache.DeleteCreateAttemptCount(ctx, userID)
|
||||||
}
|
}
|
||||||
|
s.InvalidateAuthCacheByKey(ctx, key)
|
||||||
|
|
||||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||||
return fmt.Errorf("delete api key: %w", err)
|
return fmt.Errorf("delete api key: %w", err)
|
||||||
|
|||||||
417
backend/internal/service/api_key_service_cache_test.go
Normal file
417
backend/internal/service/api_key_service_cache_test.go
Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type authRepoStub struct {
|
||||||
|
getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error)
|
||||||
|
listKeysByUserID func(ctx context.Context, userID int64) ([]string, error)
|
||||||
|
listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error {
|
||||||
|
panic("unexpected Create call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
|
panic("unexpected GetKeyAndOwnerID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByKey call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
if s.getByKeyForAuth == nil {
|
||||||
|
panic("unexpected GetByKeyForAuth call")
|
||||||
|
}
|
||||||
|
return s.getByKeyForAuth(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||||
|
panic("unexpected Update call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByUserID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||||
|
panic("unexpected VerifyOwnership call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||||
|
panic("unexpected CountByUserID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||||
|
panic("unexpected ExistsByKey call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||||
|
panic("unexpected SearchAPIKeys call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
|
panic("unexpected ClearGroupIDByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
|
panic("unexpected CountByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
if s.listKeysByUserID == nil {
|
||||||
|
panic("unexpected ListKeysByUserID call")
|
||||||
|
}
|
||||||
|
return s.listKeysByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
if s.listKeysByGroupID == nil {
|
||||||
|
panic("unexpected ListKeysByGroupID call")
|
||||||
|
}
|
||||||
|
return s.listKeysByGroupID(ctx, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
type authCacheStub struct {
|
||||||
|
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||||
|
setAuthKeys []string
|
||||||
|
deleteAuthKeys []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
if s.getAuthCache == nil {
|
||||||
|
return nil, redis.Nil
|
||||||
|
}
|
||||||
|
return s.getAuthCache(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||||
|
s.setAuthKeys = append(s.setAuthKeys, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||||
|
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
return nil, errors.New("unexpected repo call")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
groupID := int64(9)
|
||||||
|
cacheEntry := &APIKeyAuthCacheEntry{
|
||||||
|
Snapshot: &APIKeyAuthSnapshot{
|
||||||
|
APIKeyID: 1,
|
||||||
|
UserID: 2,
|
||||||
|
GroupID: &groupID,
|
||||||
|
Status: StatusActive,
|
||||||
|
User: APIKeyAuthUserSnapshot{
|
||||||
|
ID: 2,
|
||||||
|
Status: StatusActive,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: 10,
|
||||||
|
Concurrency: 3,
|
||||||
|
},
|
||||||
|
Group: &APIKeyAuthGroupSnapshot{
|
||||||
|
ID: groupID,
|
||||||
|
Name: "g",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Status: StatusActive,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
RateMultiplier: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return cacheEntry, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, err := svc.GetByKey(context.Background(), "k1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(1), apiKey.ID)
|
||||||
|
require.Equal(t, int64(2), apiKey.User.ID)
|
||||||
|
require.Equal(t, groupID, apiKey.Group.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
return nil, errors.New("unexpected repo call")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return &APIKeyAuthCacheEntry{NotFound: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.GetByKey(context.Background(), "missing")
|
||||||
|
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
return &APIKey{
|
||||||
|
ID: 5,
|
||||||
|
UserID: 7,
|
||||||
|
Status: StatusActive,
|
||||||
|
User: &User{
|
||||||
|
ID: 7,
|
||||||
|
Status: StatusActive,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: 12,
|
||||||
|
Concurrency: 2,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, redis.Nil
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, err := svc.GetByKey(context.Background(), "k2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(5), apiKey.ID)
|
||||||
|
require.Len(t, cache.setAuthKeys, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
return &APIKey{
|
||||||
|
ID: 21,
|
||||||
|
UserID: 3,
|
||||||
|
Status: StatusActive,
|
||||||
|
User: &User{
|
||||||
|
ID: 3,
|
||||||
|
Status: StatusActive,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: 5,
|
||||||
|
Concurrency: 2,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L1Size: 1000,
|
||||||
|
L1TTLSeconds: 60,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
require.NotNil(t, svc.authCacheL1)
|
||||||
|
|
||||||
|
_, err := svc.GetByKey(context.Background(), "k-l1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
svc.authCacheL1.Wait()
|
||||||
|
cacheKey := svc.authCacheKey("k-l1")
|
||||||
|
_, ok := svc.authCacheL1.Get(cacheKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
_, err = svc.GetByKey(context.Background(), "k-l1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return []string{"k1", "k2"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
|
||||||
|
require.Len(t, cache.deleteAuthKeys, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
return []string{"k1", "k2"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
|
||||||
|
require.Len(t, cache.deleteAuthKeys, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
|
||||||
|
require.Len(t, cache.deleteAuthKeys, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
return nil, ErrAPIKeyNotFound
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, redis.Nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.GetByKey(context.Background(), "missing")
|
||||||
|
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||||
|
require.Len(t, cache.setAuthKeys, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
return &APIKey{
|
||||||
|
ID: 11,
|
||||||
|
UserID: 2,
|
||||||
|
Status: StatusActive,
|
||||||
|
User: &User{
|
||||||
|
ID: 2,
|
||||||
|
Status: StatusActive,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: 1,
|
||||||
|
Concurrency: 1,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
Singleflight: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
start := make(chan struct{})
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
errs := make([]error, 5)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
<-start
|
||||||
|
_, err := svc.GetByKey(context.Background(), "k1")
|
||||||
|
errs[idx] = err
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
close(start)
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for _, err := range errs {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
|
||||||
|
}
|
||||||
@@ -20,13 +20,12 @@ import (
|
|||||||
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
|
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
|
||||||
//
|
//
|
||||||
// 设计说明:
|
// 设计说明:
|
||||||
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
|
// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误
|
||||||
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound)
|
|
||||||
// - deleteErr: 模拟 Delete 返回的错误
|
// - deleteErr: 模拟 Delete 返回的错误
|
||||||
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
||||||
type apiKeyRepoStub struct {
|
type apiKeyRepoStub struct {
|
||||||
ownerID int64 // GetOwnerID 的返回值
|
apiKey *APIKey // GetKeyAndOwnerID 的返回值
|
||||||
ownerErr error // GetOwnerID 的错误返回值
|
getByIDErr error // GetKeyAndOwnerID 的错误返回值
|
||||||
deleteErr error // Delete 的错误返回值
|
deleteErr error // Delete 的错误返回值
|
||||||
deletedIDs []int64 // 记录已删除的 API Key ID 列表
|
deletedIDs []int64 // 记录已删除的 API Key ID 列表
|
||||||
}
|
}
|
||||||
@@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||||
|
if s.getByIDErr != nil {
|
||||||
|
return nil, s.getByIDErr
|
||||||
|
}
|
||||||
|
if s.apiKey != nil {
|
||||||
|
clone := *s.apiKey
|
||||||
|
return &clone, nil
|
||||||
|
}
|
||||||
panic("unexpected GetByID call")
|
panic("unexpected GetByID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOwnerID 返回预设的所有者 ID 或错误。
|
func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
// 这是 Delete 方法调用的第一个仓储方法,用于验证调用者是否为 API Key 的所有者。
|
if s.getByIDErr != nil {
|
||||||
func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
return "", 0, s.getByIDErr
|
||||||
return s.ownerID, s.ownerErr
|
}
|
||||||
|
if s.apiKey != nil {
|
||||||
|
return s.apiKey.Key, s.apiKey.UserID, nil
|
||||||
|
}
|
||||||
|
return "", 0, ErrAPIKeyNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||||
panic("unexpected GetByKey call")
|
panic("unexpected GetByKey call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByKeyForAuth call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
|
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||||
panic("unexpected Update call")
|
panic("unexpected Update call")
|
||||||
}
|
}
|
||||||
@@ -96,13 +110,22 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
|
|||||||
panic("unexpected CountByGroupID call")
|
panic("unexpected CountByGroupID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
panic("unexpected ListKeysByUserID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
panic("unexpected ListKeysByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||||
//
|
//
|
||||||
// 设计说明:
|
// 设计说明:
|
||||||
// - invalidated: 记录被清除缓存的用户 ID 列表
|
// - invalidated: 记录被清除缓存的用户 ID 列表
|
||||||
type apiKeyCacheStub struct {
|
type apiKeyCacheStub struct {
|
||||||
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
|
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
|
||||||
|
deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制
|
// GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制
|
||||||
@@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||||
|
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
||||||
// 预期行为:
|
// 预期行为:
|
||||||
// - GetOwnerID 返回所有者 ID 为 1
|
// - GetKeyAndOwnerID 返回所有者 ID 为 1
|
||||||
// - 调用者 userID 为 2(不匹配)
|
// - 调用者 userID 为 2(不匹配)
|
||||||
// - 返回 ErrInsufficientPerms 错误
|
// - 返回 ErrInsufficientPerms 错误
|
||||||
// - Delete 方法不被调用
|
// - Delete 方法不被调用
|
||||||
// - 缓存不被清除
|
// - 缓存不被清除
|
||||||
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||||
repo := &apiKeyRepoStub{ownerID: 1}
|
repo := &apiKeyRepoStub{
|
||||||
|
apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"},
|
||||||
|
}
|
||||||
cache := &apiKeyCacheStub{}
|
cache := &apiKeyCacheStub{}
|
||||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||||
|
|
||||||
@@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
|||||||
require.ErrorIs(t, err, ErrInsufficientPerms)
|
require.ErrorIs(t, err, ErrInsufficientPerms)
|
||||||
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
|
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
|
||||||
require.Empty(t, cache.invalidated) // 验证缓存未被清除
|
require.Empty(t, cache.invalidated) // 验证缓存未被清除
|
||||||
|
require.Empty(t, cache.deleteAuthKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
|
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
|
||||||
// 预期行为:
|
// 预期行为:
|
||||||
// - GetOwnerID 返回所有者 ID 为 7
|
// - GetKeyAndOwnerID 返回所有者 ID 为 7
|
||||||
// - 调用者 userID 为 7(匹配)
|
// - 调用者 userID 为 7(匹配)
|
||||||
// - Delete 成功执行
|
// - Delete 成功执行
|
||||||
// - 缓存被正确清除(使用 ownerID)
|
// - 缓存被正确清除(使用 ownerID)
|
||||||
// - 返回 nil 错误
|
// - 返回 nil 错误
|
||||||
func TestApiKeyService_Delete_Success(t *testing.T) {
|
func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||||
repo := &apiKeyRepoStub{ownerID: 7}
|
repo := &apiKeyRepoStub{
|
||||||
|
apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"},
|
||||||
|
}
|
||||||
cache := &apiKeyCacheStub{}
|
cache := &apiKeyCacheStub{}
|
||||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||||
|
|
||||||
@@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
|
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
|
||||||
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
|
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
|
||||||
|
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
||||||
// 预期行为:
|
// 预期行为:
|
||||||
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
|
// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误
|
||||||
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
|
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||||
// - Delete 方法不被调用
|
// - Delete 方法不被调用
|
||||||
// - 缓存不被清除
|
// - 缓存不被清除
|
||||||
func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
||||||
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
|
repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound}
|
||||||
cache := &apiKeyCacheStub{}
|
cache := &apiKeyCacheStub{}
|
||||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||||
|
|
||||||
@@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
|||||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||||
require.Empty(t, repo.deletedIDs)
|
require.Empty(t, repo.deletedIDs)
|
||||||
require.Empty(t, cache.invalidated)
|
require.Empty(t, cache.invalidated)
|
||||||
|
require.Empty(t, cache.deleteAuthKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
|
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
|
||||||
// 预期行为:
|
// 预期行为:
|
||||||
// - GetOwnerID 返回正确的所有者 ID
|
// - GetKeyAndOwnerID 返回正确的所有者 ID
|
||||||
// - 所有权验证通过
|
// - 所有权验证通过
|
||||||
// - 缓存被清除(在删除之前)
|
// - 缓存被清除(在删除之前)
|
||||||
// - Delete 被调用但返回错误
|
// - Delete 被调用但返回错误
|
||||||
// - 返回包含 "delete api key" 的错误信息
|
// - 返回包含 "delete api key" 的错误信息
|
||||||
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
||||||
repo := &apiKeyRepoStub{
|
repo := &apiKeyRepoStub{
|
||||||
ownerID: 3,
|
apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"},
|
||||||
deleteErr: errors.New("delete failed"),
|
deleteErr: errors.New("delete failed"),
|
||||||
}
|
}
|
||||||
cache := &apiKeyCacheStub{}
|
cache := &apiKeyCacheStub{}
|
||||||
@@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
|||||||
require.ErrorContains(t, err, "delete api key")
|
require.ErrorContains(t, err, "delete api key")
|
||||||
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
|
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
|
||||||
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
|
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
|
||||||
|
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
|
||||||
}
|
}
|
||||||
|
|||||||
33
backend/internal/service/auth_cache_invalidation_test.go
Normal file
33
backend/internal/service/auth_cache_invalidation_test.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUsageService_InvalidateUsageCaches(t *testing.T) {
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &UsageService{authCacheInvalidator: invalidator}
|
||||||
|
|
||||||
|
svc.invalidateUsageCaches(context.Background(), 7, false)
|
||||||
|
require.Empty(t, invalidator.userIDs)
|
||||||
|
|
||||||
|
svc.invalidateUsageCaches(context.Background(), 7, true)
|
||||||
|
require.Equal(t, []int64{7}, invalidator.userIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) {
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &RedeemService{authCacheInvalidator: invalidator}
|
||||||
|
|
||||||
|
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance})
|
||||||
|
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency})
|
||||||
|
groupID := int64(3)
|
||||||
|
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeSubscription, GroupID: &groupID})
|
||||||
|
|
||||||
|
require.Equal(t, []int64{11, 11, 11}, invalidator.userIDs)
|
||||||
|
}
|
||||||
@@ -357,7 +357,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
|||||||
// - 如果邮箱已存在:直接登录(不需要本地密码)
|
// - 如果邮箱已存在:直接登录(不需要本地密码)
|
||||||
// - 如果邮箱不存在:创建新用户并登录
|
// - 如果邮箱不存在:创建新用户并登录
|
||||||
//
|
//
|
||||||
// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。
|
// 注意:该函数用于 LinuxDo OAuth 登录场景(不同于上游账号的 OAuth,例如 Claude/OpenAI/Gemini)。
|
||||||
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
|
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
|
||||||
func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) {
|
func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) {
|
||||||
email = strings.TrimSpace(email)
|
email = strings.TrimSpace(email)
|
||||||
@@ -376,8 +376,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
|||||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrUserNotFound) {
|
if errors.Is(err, ErrUserNotFound) {
|
||||||
// OAuth 首次登录视为注册。
|
// OAuth 首次登录视为注册(fail-close:settingService 未配置时不允许注册)
|
||||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||||
return "", nil, ErrRegDisabled
|
return "", nil, ErrRegDisabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
258
backend/internal/service/dashboard_aggregation_service.go
Normal file
258
backend/internal/service/dashboard_aggregation_service.go
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"log"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultDashboardAggregationTimeout = 2 * time.Minute
|
||||||
|
defaultDashboardAggregationBackfillTimeout = 30 * time.Minute
|
||||||
|
dashboardAggregationRetentionInterval = 6 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
|
||||||
|
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
|
||||||
|
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
|
||||||
|
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
|
||||||
|
)
|
||||||
|
|
||||||
|
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
|
||||||
|
type DashboardAggregationRepository interface {
|
||||||
|
AggregateRange(ctx context.Context, start, end time.Time) error
|
||||||
|
GetAggregationWatermark(ctx context.Context) (time.Time, error)
|
||||||
|
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||||
|
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||||
|
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
||||||
|
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardAggregationService 负责定时聚合与回填。
|
||||||
|
type DashboardAggregationService struct {
|
||||||
|
repo DashboardAggregationRepository
|
||||||
|
timingWheel *TimingWheelService
|
||||||
|
cfg config.DashboardAggregationConfig
|
||||||
|
running int32
|
||||||
|
lastRetentionCleanup atomic.Value // time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDashboardAggregationService 创建聚合服务。
|
||||||
|
func NewDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
||||||
|
var aggCfg config.DashboardAggregationConfig
|
||||||
|
if cfg != nil {
|
||||||
|
aggCfg = cfg.DashboardAgg
|
||||||
|
}
|
||||||
|
return &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
timingWheel: timingWheel,
|
||||||
|
cfg: aggCfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start 启动定时聚合作业(重启生效配置)。
|
||||||
|
func (s *DashboardAggregationService) Start() {
|
||||||
|
if s == nil || s.repo == nil || s.timingWheel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !s.cfg.Enabled {
|
||||||
|
log.Printf("[DashboardAggregation] 聚合作业已禁用")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
interval := time.Duration(s.cfg.IntervalSeconds) * time.Second
|
||||||
|
if interval <= 0 {
|
||||||
|
interval = time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.cfg.RecomputeDays > 0 {
|
||||||
|
go s.recomputeRecentDays()
|
||||||
|
}
|
||||||
|
|
||||||
|
s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() {
|
||||||
|
s.runScheduledAggregation()
|
||||||
|
})
|
||||||
|
log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
|
||||||
|
if !s.cfg.BackfillEnabled {
|
||||||
|
log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TriggerBackfill 触发回填(异步)。
|
||||||
|
func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) error {
|
||||||
|
if s == nil || s.repo == nil {
|
||||||
|
return errors.New("聚合服务未初始化")
|
||||||
|
}
|
||||||
|
if !s.cfg.BackfillEnabled {
|
||||||
|
log.Printf("[DashboardAggregation] 回填被拒绝: backfill_enabled=false")
|
||||||
|
return ErrDashboardBackfillDisabled
|
||||||
|
}
|
||||||
|
if !end.After(start) {
|
||||||
|
return errors.New("回填时间范围无效")
|
||||||
|
}
|
||||||
|
if s.cfg.BackfillMaxDays > 0 {
|
||||||
|
maxRange := time.Duration(s.cfg.BackfillMaxDays) * 24 * time.Hour
|
||||||
|
if end.Sub(start) > maxRange {
|
||||||
|
return ErrDashboardBackfillTooLarge
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||||
|
defer cancel()
|
||||||
|
if err := s.backfillRange(ctx, start, end); err != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 回填失败: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardAggregationService) recomputeRecentDays() {
|
||||||
|
days := s.cfg.RecomputeDays
|
||||||
|
if days <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
start := now.AddDate(0, 0, -days)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||||
|
defer cancel()
|
||||||
|
if err := s.backfillRange(ctx, start, now); err != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 启动重算失败: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardAggregationService) runScheduledAggregation() {
|
||||||
|
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer atomic.StoreInt32(&s.running, 0)
|
||||||
|
|
||||||
|
jobStart := time.Now().UTC()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
last, err := s.repo.GetAggregationWatermark(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 读取水位失败: %v", err)
|
||||||
|
last = time.Unix(0, 0).UTC()
|
||||||
|
}
|
||||||
|
|
||||||
|
lookback := time.Duration(s.cfg.LookbackSeconds) * time.Second
|
||||||
|
epoch := time.Unix(0, 0).UTC()
|
||||||
|
start := last.Add(-lookback)
|
||||||
|
if !last.After(epoch) {
|
||||||
|
retentionDays := s.cfg.Retention.UsageLogsDays
|
||||||
|
if retentionDays <= 0 {
|
||||||
|
retentionDays = 1
|
||||||
|
}
|
||||||
|
start = truncateToDayUTC(now.AddDate(0, 0, -retentionDays))
|
||||||
|
} else if start.After(now) {
|
||||||
|
start = now.Add(-lookback)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.aggregateRange(ctx, start, now); err != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 聚合失败: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updateErr := s.repo.UpdateAggregationWatermark(ctx, now)
|
||||||
|
if updateErr != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
|
||||||
|
}
|
||||||
|
log.Printf("[DashboardAggregation] 聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
|
||||||
|
start.Format(time.RFC3339),
|
||||||
|
now.Format(time.RFC3339),
|
||||||
|
time.Since(jobStart).String(),
|
||||||
|
updateErr == nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
s.maybeCleanupRetention(ctx, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||||
|
return errors.New("聚合作业正在运行")
|
||||||
|
}
|
||||||
|
defer atomic.StoreInt32(&s.running, 0)
|
||||||
|
|
||||||
|
jobStart := time.Now().UTC()
|
||||||
|
startUTC := start.UTC()
|
||||||
|
endUTC := end.UTC()
|
||||||
|
if !endUTC.After(startUTC) {
|
||||||
|
return errors.New("回填时间范围无效")
|
||||||
|
}
|
||||||
|
|
||||||
|
cursor := truncateToDayUTC(startUTC)
|
||||||
|
for cursor.Before(endUTC) {
|
||||||
|
windowEnd := cursor.Add(24 * time.Hour)
|
||||||
|
if windowEnd.After(endUTC) {
|
||||||
|
windowEnd = endUTC
|
||||||
|
}
|
||||||
|
if err := s.aggregateRange(ctx, cursor, windowEnd); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cursor = windowEnd
|
||||||
|
}
|
||||||
|
|
||||||
|
updateErr := s.repo.UpdateAggregationWatermark(ctx, endUTC)
|
||||||
|
if updateErr != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
|
||||||
|
}
|
||||||
|
log.Printf("[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
|
||||||
|
startUTC.Format(time.RFC3339),
|
||||||
|
endUTC.Format(time.RFC3339),
|
||||||
|
time.Since(jobStart).String(),
|
||||||
|
updateErr == nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
s.maybeCleanupRetention(ctx, endUTC)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
if !end.After(start) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 分区检查失败: %v", err)
|
||||||
|
}
|
||||||
|
return s.repo.AggregateRange(ctx, start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, now time.Time) {
|
||||||
|
lastAny := s.lastRetentionCleanup.Load()
|
||||||
|
if lastAny != nil {
|
||||||
|
if last, ok := lastAny.(time.Time); ok && now.Sub(last) < dashboardAggregationRetentionInterval {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
||||||
|
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
||||||
|
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
||||||
|
|
||||||
|
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
||||||
|
if aggErr != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr)
|
||||||
|
}
|
||||||
|
usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff)
|
||||||
|
if usageErr != nil {
|
||||||
|
log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||||
|
}
|
||||||
|
if aggErr == nil && usageErr == nil {
|
||||||
|
s.lastRetentionCleanup.Store(now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateToDayUTC(t time.Time) time.Time {
|
||||||
|
t = t.UTC()
|
||||||
|
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
106
backend/internal/service/dashboard_aggregation_service_test.go
Normal file
106
backend/internal/service/dashboard_aggregation_service_test.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dashboardAggregationRepoTestStub struct {
|
||||||
|
aggregateCalls int
|
||||||
|
lastStart time.Time
|
||||||
|
lastEnd time.Time
|
||||||
|
watermark time.Time
|
||||||
|
aggregateErr error
|
||||||
|
cleanupAggregatesErr error
|
||||||
|
cleanupUsageErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
s.aggregateCalls++
|
||||||
|
s.lastStart = start
|
||||||
|
s.lastEnd = end
|
||||||
|
return s.aggregateErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||||
|
return s.watermark, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||||
|
return s.cleanupAggregatesErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||||
|
return s.cleanupUsageErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
||||||
|
repo := &dashboardAggregationRepoTestStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
svc := &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
cfg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
IntervalSeconds: 60,
|
||||||
|
LookbackSeconds: 120,
|
||||||
|
Retention: config.DashboardAggregationRetentionConfig{
|
||||||
|
UsageLogsDays: 1,
|
||||||
|
HourlyDays: 1,
|
||||||
|
DailyDays: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.runScheduledAggregation()
|
||||||
|
|
||||||
|
require.Equal(t, 1, repo.aggregateCalls)
|
||||||
|
require.False(t, repo.lastEnd.IsZero())
|
||||||
|
require.Equal(t, truncateToDayUTC(repo.lastEnd.AddDate(0, 0, -1)), repo.lastStart)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *testing.T) {
|
||||||
|
repo := &dashboardAggregationRepoTestStub{cleanupAggregatesErr: errors.New("清理失败")}
|
||||||
|
svc := &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
cfg: config.DashboardAggregationConfig{
|
||||||
|
Retention: config.DashboardAggregationRetentionConfig{
|
||||||
|
UsageLogsDays: 1,
|
||||||
|
HourlyDays: 1,
|
||||||
|
DailyDays: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||||
|
|
||||||
|
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
||||||
|
repo := &dashboardAggregationRepoTestStub{}
|
||||||
|
svc := &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
cfg: config.DashboardAggregationConfig{
|
||||||
|
BackfillEnabled: true,
|
||||||
|
BackfillMaxDays: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now().AddDate(0, 0, -3)
|
||||||
|
end := time.Now()
|
||||||
|
err := svc.TriggerBackfill(start, end)
|
||||||
|
require.ErrorIs(t, err, ErrDashboardBackfillTooLarge)
|
||||||
|
require.Equal(t, 0, repo.aggregateCalls)
|
||||||
|
}
|
||||||
@@ -2,25 +2,122 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DashboardService provides aggregated statistics for admin dashboard.
|
const (
|
||||||
type DashboardService struct {
|
defaultDashboardStatsFreshTTL = 15 * time.Second
|
||||||
usageRepo UsageLogRepository
|
defaultDashboardStatsCacheTTL = 30 * time.Second
|
||||||
|
defaultDashboardStatsRefreshTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrDashboardStatsCacheMiss 标记仪表盘缓存未命中。
|
||||||
|
var ErrDashboardStatsCacheMiss = errors.New("仪表盘缓存未命中")
|
||||||
|
|
||||||
|
// DashboardStatsCache 定义仪表盘统计缓存接口。
|
||||||
|
type DashboardStatsCache interface {
|
||||||
|
GetDashboardStats(ctx context.Context) (string, error)
|
||||||
|
SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error
|
||||||
|
DeleteDashboardStats(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDashboardService(usageRepo UsageLogRepository) *DashboardService {
|
type dashboardStatsRangeFetcher interface {
|
||||||
|
GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardStatsCacheEntry struct {
|
||||||
|
Stats *usagestats.DashboardStats `json:"stats"`
|
||||||
|
UpdatedAt int64 `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardService 提供管理员仪表盘统计服务。
|
||||||
|
type DashboardService struct {
|
||||||
|
usageRepo UsageLogRepository
|
||||||
|
aggRepo DashboardAggregationRepository
|
||||||
|
cache DashboardStatsCache
|
||||||
|
cacheFreshTTL time.Duration
|
||||||
|
cacheTTL time.Duration
|
||||||
|
refreshTimeout time.Duration
|
||||||
|
refreshing int32
|
||||||
|
aggEnabled bool
|
||||||
|
aggInterval time.Duration
|
||||||
|
aggLookback time.Duration
|
||||||
|
aggUsageDays int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregationRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService {
|
||||||
|
freshTTL := defaultDashboardStatsFreshTTL
|
||||||
|
cacheTTL := defaultDashboardStatsCacheTTL
|
||||||
|
refreshTimeout := defaultDashboardStatsRefreshTimeout
|
||||||
|
aggEnabled := true
|
||||||
|
aggInterval := time.Minute
|
||||||
|
aggLookback := 2 * time.Minute
|
||||||
|
aggUsageDays := 90
|
||||||
|
if cfg != nil {
|
||||||
|
if !cfg.Dashboard.Enabled {
|
||||||
|
cache = nil
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.StatsFreshTTLSeconds > 0 {
|
||||||
|
freshTTL = time.Duration(cfg.Dashboard.StatsFreshTTLSeconds) * time.Second
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.StatsTTLSeconds > 0 {
|
||||||
|
cacheTTL = time.Duration(cfg.Dashboard.StatsTTLSeconds) * time.Second
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.StatsRefreshTimeoutSeconds > 0 {
|
||||||
|
refreshTimeout = time.Duration(cfg.Dashboard.StatsRefreshTimeoutSeconds) * time.Second
|
||||||
|
}
|
||||||
|
aggEnabled = cfg.DashboardAgg.Enabled
|
||||||
|
if cfg.DashboardAgg.IntervalSeconds > 0 {
|
||||||
|
aggInterval = time.Duration(cfg.DashboardAgg.IntervalSeconds) * time.Second
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.LookbackSeconds > 0 {
|
||||||
|
aggLookback = time.Duration(cfg.DashboardAgg.LookbackSeconds) * time.Second
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.Retention.UsageLogsDays > 0 {
|
||||||
|
aggUsageDays = cfg.DashboardAgg.Retention.UsageLogsDays
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if aggRepo == nil {
|
||||||
|
aggEnabled = false
|
||||||
|
}
|
||||||
return &DashboardService{
|
return &DashboardService{
|
||||||
usageRepo: usageRepo,
|
usageRepo: usageRepo,
|
||||||
|
aggRepo: aggRepo,
|
||||||
|
cache: cache,
|
||||||
|
cacheFreshTTL: freshTTL,
|
||||||
|
cacheTTL: cacheTTL,
|
||||||
|
refreshTimeout: refreshTimeout,
|
||||||
|
aggEnabled: aggEnabled,
|
||||||
|
aggInterval: aggInterval,
|
||||||
|
aggLookback: aggLookback,
|
||||||
|
aggUsageDays: aggUsageDays,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||||
stats, err := s.usageRepo.GetDashboardStats(ctx)
|
if s.cache != nil {
|
||||||
|
cached, fresh, err := s.getCachedDashboardStats(ctx)
|
||||||
|
if err == nil && cached != nil {
|
||||||
|
s.refreshAggregationStaleness(cached)
|
||||||
|
if !fresh {
|
||||||
|
s.refreshDashboardStatsAsync()
|
||||||
|
}
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) {
|
||||||
|
log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := s.refreshDashboardStats(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get dashboard stats: %w", err)
|
return nil, fmt.Errorf("get dashboard stats: %w", err)
|
||||||
}
|
}
|
||||||
@@ -43,6 +140,169 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
|
|||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
|
||||||
|
data, err := s.cache.GetDashboardStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var entry dashboardStatsCacheEntry
|
||||||
|
if err := json.Unmarshal([]byte(data), &entry); err != nil {
|
||||||
|
s.evictDashboardStatsCache(err)
|
||||||
|
return nil, false, ErrDashboardStatsCacheMiss
|
||||||
|
}
|
||||||
|
if entry.Stats == nil {
|
||||||
|
s.evictDashboardStatsCache(errors.New("仪表盘缓存缺少统计数据"))
|
||||||
|
return nil, false, ErrDashboardStatsCacheMiss
|
||||||
|
}
|
||||||
|
|
||||||
|
age := time.Since(time.Unix(entry.UpdatedAt, 0))
|
||||||
|
return entry.Stats, age <= s.cacheFreshTTL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||||
|
stats, err := s.fetchDashboardStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.applyAggregationStatus(ctx, stats)
|
||||||
|
cacheCtx, cancel := s.cacheOperationContext()
|
||||||
|
defer cancel()
|
||||||
|
s.saveDashboardStatsCache(cacheCtx, stats)
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) refreshDashboardStatsAsync() {
|
||||||
|
if s.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !atomic.CompareAndSwapInt32(&s.refreshing, 0, 1) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer atomic.StoreInt32(&s.refreshing, 0)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), s.refreshTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
stats, err := s.fetchDashboardStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.applyAggregationStatus(ctx, stats)
|
||||||
|
cacheCtx, cancel := s.cacheOperationContext()
|
||||||
|
defer cancel()
|
||||||
|
s.saveDashboardStatsCache(cacheCtx, stats)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) fetchDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||||
|
if !s.aggEnabled {
|
||||||
|
if fetcher, ok := s.usageRepo.(dashboardStatsRangeFetcher); ok {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
start := truncateToDayUTC(now.AddDate(0, 0, -s.aggUsageDays))
|
||||||
|
return fetcher.GetDashboardStatsWithRange(ctx, start, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.usageRepo.GetDashboardStats(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *usagestats.DashboardStats) {
|
||||||
|
if s.cache == nil || stats == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := dashboardStatsCacheEntry{
|
||||||
|
Stats: stats,
|
||||||
|
UpdatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(entry)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil {
|
||||||
|
log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) evictDashboardStatsCache(reason error) {
|
||||||
|
if s.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cacheCtx, cancel := s.cacheOperationContext()
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil {
|
||||||
|
log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", err)
|
||||||
|
}
|
||||||
|
if reason != nil {
|
||||||
|
log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) cacheOperationContext() (context.Context, context.CancelFunc) {
|
||||||
|
return context.WithTimeout(context.Background(), s.refreshTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) applyAggregationStatus(ctx context.Context, stats *usagestats.DashboardStats) {
|
||||||
|
if stats == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updatedAt := s.fetchAggregationUpdatedAt(ctx)
|
||||||
|
stats.StatsUpdatedAt = updatedAt.UTC().Format(time.RFC3339)
|
||||||
|
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) refreshAggregationStaleness(stats *usagestats.DashboardStats) {
|
||||||
|
if stats == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updatedAt := parseStatsUpdatedAt(stats.StatsUpdatedAt)
|
||||||
|
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.Time {
|
||||||
|
if s.aggRepo == nil {
|
||||||
|
return time.Unix(0, 0).UTC()
|
||||||
|
}
|
||||||
|
updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Dashboard] 读取聚合水位失败: %v", err)
|
||||||
|
return time.Unix(0, 0).UTC()
|
||||||
|
}
|
||||||
|
if updatedAt.IsZero() {
|
||||||
|
return time.Unix(0, 0).UTC()
|
||||||
|
}
|
||||||
|
return updatedAt.UTC()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) isAggregationStale(updatedAt, now time.Time) bool {
|
||||||
|
if !s.aggEnabled {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
epoch := time.Unix(0, 0).UTC()
|
||||||
|
if !updatedAt.After(epoch) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
threshold := s.aggInterval + s.aggLookback
|
||||||
|
return now.Sub(updatedAt) > threshold
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseStatsUpdatedAt(raw string) time.Time {
|
||||||
|
if raw == "" {
|
||||||
|
return time.Unix(0, 0).UTC()
|
||||||
|
}
|
||||||
|
parsed, err := time.Parse(time.RFC3339, raw)
|
||||||
|
if err != nil {
|
||||||
|
return time.Unix(0, 0).UTC()
|
||||||
|
}
|
||||||
|
return parsed.UTC()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||||
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
387
backend/internal/service/dashboard_service_test.go
Normal file
387
backend/internal/service/dashboard_service_test.go
Normal file
@@ -0,0 +1,387 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type usageRepoStub struct {
|
||||||
|
UsageLogRepository
|
||||||
|
stats *usagestats.DashboardStats
|
||||||
|
rangeStats *usagestats.DashboardStats
|
||||||
|
err error
|
||||||
|
rangeErr error
|
||||||
|
calls int32
|
||||||
|
rangeCalls int32
|
||||||
|
rangeStart time.Time
|
||||||
|
rangeEnd time.Time
|
||||||
|
onCall chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||||
|
atomic.AddInt32(&s.calls, 1)
|
||||||
|
if s.onCall != nil {
|
||||||
|
select {
|
||||||
|
case s.onCall <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.err != nil {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
return s.stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *usageRepoStub) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) {
|
||||||
|
atomic.AddInt32(&s.rangeCalls, 1)
|
||||||
|
s.rangeStart = start
|
||||||
|
s.rangeEnd = end
|
||||||
|
if s.rangeErr != nil {
|
||||||
|
return nil, s.rangeErr
|
||||||
|
}
|
||||||
|
if s.rangeStats != nil {
|
||||||
|
return s.rangeStats, nil
|
||||||
|
}
|
||||||
|
return s.stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardCacheStub struct {
|
||||||
|
get func(ctx context.Context) (string, error)
|
||||||
|
set func(ctx context.Context, data string, ttl time.Duration) error
|
||||||
|
del func(ctx context.Context) error
|
||||||
|
getCalls int32
|
||||||
|
setCalls int32
|
||||||
|
delCalls int32
|
||||||
|
lastSetMu sync.Mutex
|
||||||
|
lastSet string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCacheStub) GetDashboardStats(ctx context.Context) (string, error) {
|
||||||
|
atomic.AddInt32(&c.getCalls, 1)
|
||||||
|
if c.get != nil {
|
||||||
|
return c.get(ctx)
|
||||||
|
}
|
||||||
|
return "", ErrDashboardStatsCacheMiss
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCacheStub) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
|
||||||
|
atomic.AddInt32(&c.setCalls, 1)
|
||||||
|
c.lastSetMu.Lock()
|
||||||
|
c.lastSet = data
|
||||||
|
c.lastSetMu.Unlock()
|
||||||
|
if c.set != nil {
|
||||||
|
return c.set(ctx, data, ttl)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCacheStub) DeleteDashboardStats(ctx context.Context) error {
|
||||||
|
atomic.AddInt32(&c.delCalls, 1)
|
||||||
|
if c.del != nil {
|
||||||
|
return c.del(ctx)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardAggregationRepoStub struct {
|
||||||
|
watermark time.Time
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||||
|
if s.err != nil {
|
||||||
|
return time.Time{}, s.err
|
||||||
|
}
|
||||||
|
return s.watermark, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntry {
|
||||||
|
t.Helper()
|
||||||
|
c.lastSetMu.Lock()
|
||||||
|
data := c.lastSet
|
||||||
|
c.lastSetMu.Unlock()
|
||||||
|
|
||||||
|
var entry dashboardStatsCacheEntry
|
||||||
|
err := json.Unmarshal([]byte(data), &entry)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_CacheHitFresh(t *testing.T) {
|
||||||
|
stats := &usagestats.DashboardStats{
|
||||||
|
TotalUsers: 10,
|
||||||
|
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||||
|
StatsStale: true,
|
||||||
|
}
|
||||||
|
entry := dashboardStatsCacheEntry{
|
||||||
|
Stats: stats,
|
||||||
|
UpdatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(entry)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cache := &dashboardCacheStub{
|
||||||
|
get: func(ctx context.Context) (string, error) {
|
||||||
|
return string(payload), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &usageRepoStub{
|
||||||
|
stats: &usagestats.DashboardStats{TotalUsers: 99},
|
||||||
|
}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, stats, got)
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_CacheMiss_StoresCache(t *testing.T) {
|
||||||
|
stats := &usagestats.DashboardStats{
|
||||||
|
TotalUsers: 7,
|
||||||
|
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||||
|
StatsStale: true,
|
||||||
|
}
|
||||||
|
cache := &dashboardCacheStub{
|
||||||
|
get: func(ctx context.Context) (string, error) {
|
||||||
|
return "", ErrDashboardStatsCacheMiss
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &usageRepoStub{stats: stats}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, stats, got)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&cache.setCalls))
|
||||||
|
entry := cache.readLastEntry(t)
|
||||||
|
require.Equal(t, stats, entry.Stats)
|
||||||
|
require.WithinDuration(t, time.Now(), time.Unix(entry.UpdatedAt, 0), time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) {
|
||||||
|
stats := &usagestats.DashboardStats{
|
||||||
|
TotalUsers: 3,
|
||||||
|
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||||
|
StatsStale: true,
|
||||||
|
}
|
||||||
|
cache := &dashboardCacheStub{
|
||||||
|
get: func(ctx context.Context) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &usageRepoStub{stats: stats}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, stats, got)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalls))
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) {
|
||||||
|
staleStats := &usagestats.DashboardStats{
|
||||||
|
TotalUsers: 11,
|
||||||
|
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||||
|
StatsStale: true,
|
||||||
|
}
|
||||||
|
entry := dashboardStatsCacheEntry{
|
||||||
|
Stats: staleStats,
|
||||||
|
UpdatedAt: time.Now().Add(-defaultDashboardStatsFreshTTL * 2).Unix(),
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(entry)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cache := &dashboardCacheStub{
|
||||||
|
get: func(ctx context.Context) (string, error) {
|
||||||
|
return string(payload), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
refreshCh := make(chan struct{}, 1)
|
||||||
|
repo := &usageRepoStub{
|
||||||
|
stats: &usagestats.DashboardStats{TotalUsers: 22},
|
||||||
|
onCall: refreshCh,
|
||||||
|
}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, staleStats, got)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-refreshCh:
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Fatal("等待异步刷新超时")
|
||||||
|
}
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return atomic.LoadInt32(&cache.setCalls) >= 1
|
||||||
|
}, 1*time.Second, 10*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) {
|
||||||
|
cache := &dashboardCacheStub{
|
||||||
|
get: func(ctx context.Context) (string, error) {
|
||||||
|
return "not-json", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
stats := &usagestats.DashboardStats{TotalUsers: 9}
|
||||||
|
repo := &usageRepoStub{stats: stats}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, stats, got)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) {
|
||||||
|
cache := &dashboardCacheStub{
|
||||||
|
get: func(ctx context.Context) (string, error) {
|
||||||
|
return "not-json", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &usageRepoStub{err: errors.New("db down")}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||||
|
|
||||||
|
_, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_StatsUpdatedAtEpochWhenMissing(t *testing.T) {
|
||||||
|
stats := &usagestats.DashboardStats{}
|
||||||
|
repo := &usageRepoStub{stats: stats}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: false}}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, nil, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "1970-01-01T00:00:00Z", got.StatsUpdatedAt)
|
||||||
|
require.True(t, got.StatsStale)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_StatsStaleFalseWhenFresh(t *testing.T) {
|
||||||
|
aggNow := time.Now().UTC().Truncate(time.Second)
|
||||||
|
stats := &usagestats.DashboardStats{}
|
||||||
|
repo := &usageRepoStub{stats: stats}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: aggNow}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
IntervalSeconds: 60,
|
||||||
|
LookbackSeconds: 120,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, nil, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, aggNow.Format(time.RFC3339), got.StatsUpdatedAt)
|
||||||
|
require.False(t, got.StatsStale)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_AggDisabled_UsesUsageLogsFallback(t *testing.T) {
|
||||||
|
expected := &usagestats.DashboardStats{TotalUsers: 42}
|
||||||
|
repo := &usageRepoStub{
|
||||||
|
rangeStats: expected,
|
||||||
|
err: errors.New("should not call aggregated stats"),
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: false,
|
||||||
|
Retention: config.DashboardAggregationRetentionConfig{
|
||||||
|
UsageLogsDays: 7,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, nil, nil, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(42), got.TotalUsers)
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&repo.rangeCalls))
|
||||||
|
require.False(t, repo.rangeEnd.IsZero())
|
||||||
|
require.Equal(t, truncateToDayUTC(repo.rangeEnd.AddDate(0, 0, -7)), repo.rangeStart)
|
||||||
|
}
|
||||||
@@ -63,6 +63,9 @@ const (
|
|||||||
SubscriptionStatusSuspended = "suspended"
|
SubscriptionStatusSuspended = "suspended"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
|
||||||
|
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||||
|
|
||||||
// Setting keys
|
// Setting keys
|
||||||
const (
|
const (
|
||||||
// 注册设置
|
// 注册设置
|
||||||
@@ -83,6 +86,12 @@ const (
|
|||||||
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
|
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
|
||||||
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
|
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
|
||||||
|
|
||||||
|
// LinuxDo Connect OAuth 登录设置
|
||||||
|
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
||||||
|
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
||||||
|
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
|
||||||
|
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||||
|
|
||||||
// OEM设置
|
// OEM设置
|
||||||
SettingKeySiteName = "site_name" // 网站名称
|
SettingKeySiteName = "site_name" // 网站名称
|
||||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||||
@@ -113,16 +122,38 @@ const (
|
|||||||
SettingKeyEnableIdentityPatch = "enable_identity_patch"
|
SettingKeyEnableIdentityPatch = "enable_identity_patch"
|
||||||
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
|
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
|
||||||
|
|
||||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
// =========================
|
||||||
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
// Ops Monitoring (vNext)
|
||||||
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
// =========================
|
||||||
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
|
|
||||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
|
||||||
)
|
|
||||||
|
|
||||||
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
|
// SettingKeyOpsMonitoringEnabled is a DB-backed soft switch to enable/disable ops module at runtime.
|
||||||
// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。
|
SettingKeyOpsMonitoringEnabled = "ops_monitoring_enabled"
|
||||||
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
|
||||||
|
// SettingKeyOpsRealtimeMonitoringEnabled controls realtime features (e.g. WS/QPS push).
|
||||||
|
SettingKeyOpsRealtimeMonitoringEnabled = "ops_realtime_monitoring_enabled"
|
||||||
|
|
||||||
|
// SettingKeyOpsQueryModeDefault controls the default query mode for ops dashboard (auto/raw/preagg).
|
||||||
|
SettingKeyOpsQueryModeDefault = "ops_query_mode_default"
|
||||||
|
|
||||||
|
// SettingKeyOpsEmailNotificationConfig stores JSON config for ops email notifications.
|
||||||
|
SettingKeyOpsEmailNotificationConfig = "ops_email_notification_config"
|
||||||
|
|
||||||
|
// SettingKeyOpsAlertRuntimeSettings stores JSON config for ops alert evaluator runtime settings.
|
||||||
|
SettingKeyOpsAlertRuntimeSettings = "ops_alert_runtime_settings"
|
||||||
|
|
||||||
|
// SettingKeyOpsMetricsIntervalSeconds controls the ops metrics collector interval (>=60).
|
||||||
|
SettingKeyOpsMetricsIntervalSeconds = "ops_metrics_interval_seconds"
|
||||||
|
|
||||||
|
// SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation).
|
||||||
|
SettingKeyOpsAdvancedSettings = "ops_advanced_settings"
|
||||||
|
|
||||||
|
// =========================
|
||||||
|
// Stream Timeout Handling
|
||||||
|
// =========================
|
||||||
|
|
||||||
|
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
|
||||||
|
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
|
||||||
|
)
|
||||||
|
|
||||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||||
const AdminAPIKeyPrefix = "admin-"
|
const AdminAPIKeyPrefix = "admin-"
|
||||||
|
|||||||
@@ -151,6 +151,7 @@ type GatewayService struct {
|
|||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
cache GatewayCache
|
cache GatewayCache
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
schedulerSnapshot *SchedulerSnapshotService
|
||||||
billingService *BillingService
|
billingService *BillingService
|
||||||
rateLimitService *RateLimitService
|
rateLimitService *RateLimitService
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
@@ -169,6 +170,7 @@ func NewGatewayService(
|
|||||||
userSubRepo UserSubscriptionRepository,
|
userSubRepo UserSubscriptionRepository,
|
||||||
cache GatewayCache,
|
cache GatewayCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
|
schedulerSnapshot *SchedulerSnapshotService,
|
||||||
concurrencyService *ConcurrencyService,
|
concurrencyService *ConcurrencyService,
|
||||||
billingService *BillingService,
|
billingService *BillingService,
|
||||||
rateLimitService *RateLimitService,
|
rateLimitService *RateLimitService,
|
||||||
@@ -185,6 +187,7 @@ func NewGatewayService(
|
|||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
schedulerSnapshot: schedulerSnapshot,
|
||||||
concurrencyService: concurrencyService,
|
concurrencyService: concurrencyService,
|
||||||
billingService: billingService,
|
billingService: billingService,
|
||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
@@ -745,6 +748,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
||||||
|
if s.schedulerSnapshot != nil {
|
||||||
|
return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||||
|
}
|
||||||
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
|
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
|
||||||
if useMixed {
|
if useMixed {
|
||||||
platforms := []string{platform, PlatformAntigravity}
|
platforms := []string{platform, PlatformAntigravity}
|
||||||
@@ -821,6 +827,13 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
|
|||||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||||
|
if s.schedulerSnapshot != nil {
|
||||||
|
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
return s.accountRepo.GetByID(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||||
sort.SliceStable(accounts, func(i, j int) bool {
|
sort.SliceStable(accounts, func(i, j int) bool {
|
||||||
a, b := accounts[i], accounts[j]
|
a, b := accounts[i], accounts[j]
|
||||||
@@ -851,7 +864,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||||
@@ -864,16 +877,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. 获取可调度账号列表(单平台)
|
// 2. 获取可调度账号列表(单平台)
|
||||||
var accounts []Account
|
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||||
var err error
|
if hasForcePlatform && forcePlatform == "" {
|
||||||
if s.cfg.RunMode == config.RunModeSimple {
|
hasForcePlatform = false
|
||||||
// 简易模式:忽略 groupID,查询所有可用账号
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
|
||||||
} else if groupID != nil {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
|
||||||
} else {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
|
||||||
}
|
}
|
||||||
|
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -935,7 +943,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
// selectAccountWithMixedScheduling 选择账户(支持混合调度)
|
// selectAccountWithMixedScheduling 选择账户(支持混合调度)
|
||||||
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
|
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
|
||||||
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
|
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
|
||||||
platforms := []string{nativePlatform, PlatformAntigravity}
|
|
||||||
preferOAuth := nativePlatform == PlatformGemini
|
preferOAuth := nativePlatform == PlatformGemini
|
||||||
|
|
||||||
// 1. 查询粘性会话
|
// 1. 查询粘性会话
|
||||||
@@ -943,7 +950,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||||
@@ -958,13 +965,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. 获取可调度账号列表
|
// 2. 获取可调度账号列表
|
||||||
var accounts []Account
|
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
|
||||||
var err error
|
|
||||||
if groupID != nil {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
|
|
||||||
} else {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1226,6 +1227,9 @@ func enforceCacheControlLimit(body []byte) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 清理 thinking 块中的非法 cache_control(thinking 块不支持该字段)
|
||||||
|
removeCacheControlFromThinkingBlocks(data)
|
||||||
|
|
||||||
// 计算当前 cache_control 块数量
|
// 计算当前 cache_control 块数量
|
||||||
count := countCacheControlBlocks(data)
|
count := countCacheControlBlocks(data)
|
||||||
if count <= maxCacheControlBlocks {
|
if count <= maxCacheControlBlocks {
|
||||||
@@ -1253,6 +1257,7 @@ func enforceCacheControlLimit(body []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
|
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
|
||||||
|
// 注意:thinking 块不支持 cache_control,统计时跳过
|
||||||
func countCacheControlBlocks(data map[string]any) int {
|
func countCacheControlBlocks(data map[string]any) int {
|
||||||
count := 0
|
count := 0
|
||||||
|
|
||||||
@@ -1260,6 +1265,10 @@ func countCacheControlBlocks(data map[string]any) int {
|
|||||||
if system, ok := data["system"].([]any); ok {
|
if system, ok := data["system"].([]any); ok {
|
||||||
for _, item := range system {
|
for _, item := range system {
|
||||||
if m, ok := item.(map[string]any); ok {
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
// thinking 块不支持 cache_control,跳过
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if _, has := m["cache_control"]; has {
|
if _, has := m["cache_control"]; has {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
@@ -1274,6 +1283,10 @@ func countCacheControlBlocks(data map[string]any) int {
|
|||||||
if content, ok := msgMap["content"].([]any); ok {
|
if content, ok := msgMap["content"].([]any); ok {
|
||||||
for _, item := range content {
|
for _, item := range content {
|
||||||
if m, ok := item.(map[string]any); ok {
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
// thinking 块不支持 cache_control,跳过
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if _, has := m["cache_control"]; has {
|
if _, has := m["cache_control"]; has {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
@@ -1289,6 +1302,7 @@ func countCacheControlBlocks(data map[string]any) int {
|
|||||||
|
|
||||||
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始)
|
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始)
|
||||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
// 返回 true 表示成功移除,false 表示没有可移除的
|
||||||
|
// 注意:跳过 thinking 块(它不支持 cache_control)
|
||||||
func removeCacheControlFromMessages(data map[string]any) bool {
|
func removeCacheControlFromMessages(data map[string]any) bool {
|
||||||
messages, ok := data["messages"].([]any)
|
messages, ok := data["messages"].([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -1306,6 +1320,10 @@ func removeCacheControlFromMessages(data map[string]any) bool {
|
|||||||
}
|
}
|
||||||
for _, item := range content {
|
for _, item := range content {
|
||||||
if m, ok := item.(map[string]any); ok {
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
// thinking 块不支持 cache_control,跳过
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if _, has := m["cache_control"]; has {
|
if _, has := m["cache_control"]; has {
|
||||||
delete(m, "cache_control")
|
delete(m, "cache_control")
|
||||||
return true
|
return true
|
||||||
@@ -1318,6 +1336,7 @@ func removeCacheControlFromMessages(data map[string]any) bool {
|
|||||||
|
|
||||||
// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt)
|
// removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt)
|
||||||
// 返回 true 表示成功移除,false 表示没有可移除的
|
// 返回 true 表示成功移除,false 表示没有可移除的
|
||||||
|
// 注意:跳过 thinking 块(它不支持 cache_control)
|
||||||
func removeCacheControlFromSystem(data map[string]any) bool {
|
func removeCacheControlFromSystem(data map[string]any) bool {
|
||||||
system, ok := data["system"].([]any)
|
system, ok := data["system"].([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -1327,6 +1346,10 @@ func removeCacheControlFromSystem(data map[string]any) bool {
|
|||||||
// 从尾部开始移除,保护开头注入的 Claude Code prompt
|
// 从尾部开始移除,保护开头注入的 Claude Code prompt
|
||||||
for i := len(system) - 1; i >= 0; i-- {
|
for i := len(system) - 1; i >= 0; i-- {
|
||||||
if m, ok := system[i].(map[string]any); ok {
|
if m, ok := system[i].(map[string]any); ok {
|
||||||
|
// thinking 块不支持 cache_control,跳过
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if _, has := m["cache_control"]; has {
|
if _, has := m["cache_control"]; has {
|
||||||
delete(m, "cache_control")
|
delete(m, "cache_control")
|
||||||
return true
|
return true
|
||||||
@@ -1336,6 +1359,44 @@ func removeCacheControlFromSystem(data map[string]any) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeCacheControlFromThinkingBlocks 强制清理所有 thinking 块中的非法 cache_control
|
||||||
|
// thinking 块不支持 cache_control 字段,这个函数确保所有 thinking 块都不含该字段
|
||||||
|
func removeCacheControlFromThinkingBlocks(data map[string]any) {
|
||||||
|
// 清理 system 中的 thinking 块
|
||||||
|
if system, ok := data["system"].([]any); ok {
|
||||||
|
for _, item := range system {
|
||||||
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
if _, has := m["cache_control"]; has {
|
||||||
|
delete(m, "cache_control")
|
||||||
|
log.Printf("[Warning] Removed illegal cache_control from thinking block in system")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理 messages 中的 thinking 块
|
||||||
|
if messages, ok := data["messages"].([]any); ok {
|
||||||
|
for msgIdx, msg := range messages {
|
||||||
|
if msgMap, ok := msg.(map[string]any); ok {
|
||||||
|
if content, ok := msgMap["content"].([]any); ok {
|
||||||
|
for contentIdx, item := range content {
|
||||||
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
if blockType, _ := m["type"].(string); blockType == "thinking" {
|
||||||
|
if _, has := m["cache_control"]; has {
|
||||||
|
delete(m, "cache_control")
|
||||||
|
log.Printf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Forward 转发请求到Claude API
|
// Forward 转发请求到Claude API
|
||||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
@@ -1399,7 +1460,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
if resp != nil && resp.Body != nil {
|
if resp != nil && resp.Body != nil {
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: safeErr,
|
||||||
|
})
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"type": "error",
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream request failed",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 优先检测thinking block签名错误(400)并重试一次
|
// 优先检测thinking block签名错误(400)并重试一次
|
||||||
@@ -1409,6 +1487,21 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
if s.isThinkingBlockSignatureError(respBody) {
|
if s.isThinkingBlockSignatureError(respBody) {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "signature_error",
|
||||||
|
Message: extractUpstreamErrorMessage(respBody),
|
||||||
|
Detail: func() string {
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}(),
|
||||||
|
})
|
||||||
|
|
||||||
looksLikeToolSignatureError := func(msg string) bool {
|
looksLikeToolSignatureError := func(msg string) bool {
|
||||||
m := strings.ToLower(msg)
|
m := strings.ToLower(msg)
|
||||||
return strings.Contains(m, "tool_use") ||
|
return strings.Contains(m, "tool_use") ||
|
||||||
@@ -1445,6 +1538,20 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
retryRespBody, retryReadErr := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
retryRespBody, retryReadErr := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||||
_ = retryResp.Body.Close()
|
_ = retryResp.Body.Close()
|
||||||
if retryReadErr == nil && retryResp.StatusCode == 400 && s.isThinkingBlockSignatureError(retryRespBody) {
|
if retryReadErr == nil && retryResp.StatusCode == 400 && s.isThinkingBlockSignatureError(retryRespBody) {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: retryResp.StatusCode,
|
||||||
|
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||||||
|
Kind: "signature_retry_thinking",
|
||||||
|
Message: extractUpstreamErrorMessage(retryRespBody),
|
||||||
|
Detail: func() string {
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
return truncateString(string(retryRespBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}(),
|
||||||
|
})
|
||||||
msg2 := extractUpstreamErrorMessage(retryRespBody)
|
msg2 := extractUpstreamErrorMessage(retryRespBody)
|
||||||
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
|
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
|
||||||
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
||||||
@@ -1459,6 +1566,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
if retryResp2 != nil && retryResp2.Body != nil {
|
if retryResp2 != nil && retryResp2.Body != nil {
|
||||||
_ = retryResp2.Body.Close()
|
_ = retryResp2.Body.Close()
|
||||||
}
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "signature_retry_tools_request_error",
|
||||||
|
Message: sanitizeUpstreamErrorMessage(retryErr2.Error()),
|
||||||
|
})
|
||||||
log.Printf("Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2)
|
log.Printf("Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2)
|
||||||
} else {
|
} else {
|
||||||
log.Printf("Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2)
|
log.Printf("Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2)
|
||||||
@@ -1508,9 +1622,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "retry",
|
||||||
|
Message: extractUpstreamErrorMessage(respBody),
|
||||||
|
Detail: func() string {
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}(),
|
||||||
|
})
|
||||||
log.Printf("Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)",
|
log.Printf("Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)",
|
||||||
account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed)
|
account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed)
|
||||||
_ = resp.Body.Close()
|
|
||||||
if err := sleepWithContext(ctx, delay); err != nil {
|
if err := sleepWithContext(ctx, delay); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1538,7 +1667,25 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
// 处理重试耗尽的情况
|
// 处理重试耗尽的情况
|
||||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
|
||||||
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "retry_exhausted_failover",
|
||||||
|
Message: extractUpstreamErrorMessage(respBody),
|
||||||
|
Detail: func() string {
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}(),
|
||||||
|
})
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
}
|
}
|
||||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||||
@@ -1546,7 +1693,25 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
|
|
||||||
// 处理可切换账号的错误
|
// 处理可切换账号的错误
|
||||||
if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) {
|
if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
|
||||||
s.handleFailoverSideEffects(ctx, resp, account)
|
s.handleFailoverSideEffects(ctx, resp, account)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "failover",
|
||||||
|
Message: extractUpstreamErrorMessage(respBody),
|
||||||
|
Detail: func() string {
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}(),
|
||||||
|
})
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1563,6 +1728,26 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
|
||||||
if s.shouldFailoverOn400(respBody) {
|
if s.shouldFailoverOn400(respBody) {
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "failover_on_400",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
if s.cfg.Gateway.LogUpstreamErrorBody {
|
if s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
log.Printf(
|
log.Printf(
|
||||||
"Account %d: 400 error, attempting failover: %s",
|
"Account %d: 400 error, attempting failover: %s",
|
||||||
@@ -1859,7 +2044,30 @@ func extractUpstreamErrorMessage(body []byte) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
|
||||||
|
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(body), maxBytes)
|
||||||
|
}
|
||||||
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "http_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
// 处理上游错误,标记账号状态
|
// 处理上游错误,标记账号状态
|
||||||
shouldDisable := false
|
shouldDisable := false
|
||||||
@@ -1870,24 +2078,33 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
|||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
log.Printf(
|
||||||
|
"Upstream error %d (account=%d platform=%s type=%s): %s",
|
||||||
|
resp.StatusCode,
|
||||||
|
account.ID,
|
||||||
|
account.Platform,
|
||||||
|
account.Type,
|
||||||
|
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
|
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
|
||||||
var errType, errMsg string
|
var errType, errMsg string
|
||||||
var statusCode int
|
var statusCode int
|
||||||
|
|
||||||
switch resp.StatusCode {
|
switch resp.StatusCode {
|
||||||
case 400:
|
case 400:
|
||||||
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
|
|
||||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
|
||||||
log.Printf(
|
|
||||||
"Upstream 400 error (account=%d platform=%s type=%s): %s",
|
|
||||||
account.ID,
|
|
||||||
account.Platform,
|
|
||||||
account.Type,
|
|
||||||
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
c.Data(http.StatusBadRequest, "application/json", body)
|
c.Data(http.StatusBadRequest, "application/json", body)
|
||||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
summary := upstreamMsg
|
||||||
|
if summary == "" {
|
||||||
|
summary = truncateForLog(body, 512)
|
||||||
|
}
|
||||||
|
if summary == "" {
|
||||||
|
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, summary)
|
||||||
case 401:
|
case 401:
|
||||||
statusCode = http.StatusBadGateway
|
statusCode = http.StatusBadGateway
|
||||||
errType = "upstream_error"
|
errType = "upstream_error"
|
||||||
@@ -1923,11 +2140,14 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
if upstreamMsg == "" {
|
||||||
|
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
statusCode := resp.StatusCode
|
statusCode := resp.StatusCode
|
||||||
|
|
||||||
// OAuth/Setup Token 账号的 403:标记账号异常
|
// OAuth/Setup Token 账号的 403:标记账号异常
|
||||||
@@ -1941,7 +2161,7 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1949,8 +2169,45 @@ func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *ht
|
|||||||
// OAuth 403:标记账号异常
|
// OAuth 403:标记账号异常
|
||||||
// API Key 未配置错误码:仅返回错误,不标记账号
|
// API Key 未配置错误码:仅返回错误,不标记账号
|
||||||
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||||
|
// Capture upstream error body before side-effects consume the stream.
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
|
||||||
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
||||||
|
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "retry_exhausted",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
log.Printf(
|
||||||
|
"Upstream error %d retries_exhausted (account=%d platform=%s type=%s): %s",
|
||||||
|
resp.StatusCode,
|
||||||
|
account.ID,
|
||||||
|
account.Platform,
|
||||||
|
account.Type,
|
||||||
|
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// 返回统一的重试耗尽错误响应
|
// 返回统一的重试耗尽错误响应
|
||||||
c.JSON(http.StatusBadGateway, gin.H{
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
"type": "error",
|
"type": "error",
|
||||||
@@ -1960,7 +2217,10 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode)
|
if upstreamMsg == "" {
|
||||||
|
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("upstream error: %d (retries exhausted) message=%s", resp.StatusCode, upstreamMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// streamingResult 流式响应结果
|
// streamingResult 流式响应结果
|
||||||
@@ -2141,6 +2401,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
}
|
}
|
||||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||||
|
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||||
|
}
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
}
|
}
|
||||||
@@ -2490,6 +2754,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
// 发送请求
|
// 发送请求
|
||||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "")
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||||
return fmt.Errorf("upstream request failed: %w", err)
|
return fmt.Errorf("upstream request failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -2527,6 +2792,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
// 标记账号状态(429/529等)
|
// 标记账号状态(429/529等)
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
|
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||||
|
|
||||||
// 记录上游错误摘要便于排障(不回显请求内容)
|
// 记录上游错误摘要便于排障(不回显请求内容)
|
||||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
log.Printf(
|
log.Printf(
|
||||||
@@ -2548,7 +2825,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
errMsg = "Service overloaded"
|
errMsg = "Service overloaded"
|
||||||
}
|
}
|
||||||
s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg)
|
s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg)
|
||||||
return fmt.Errorf("upstream error: %d", resp.StatusCode)
|
if upstreamMsg == "" {
|
||||||
|
return fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 透传成功响应
|
// 透传成功响应
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ type GeminiMessagesCompatService struct {
|
|||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
groupRepo GroupRepository
|
groupRepo GroupRepository
|
||||||
cache GatewayCache
|
cache GatewayCache
|
||||||
|
schedulerSnapshot *SchedulerSnapshotService
|
||||||
tokenProvider *GeminiTokenProvider
|
tokenProvider *GeminiTokenProvider
|
||||||
rateLimitService *RateLimitService
|
rateLimitService *RateLimitService
|
||||||
httpUpstream HTTPUpstream
|
httpUpstream HTTPUpstream
|
||||||
@@ -51,6 +52,7 @@ func NewGeminiMessagesCompatService(
|
|||||||
accountRepo AccountRepository,
|
accountRepo AccountRepository,
|
||||||
groupRepo GroupRepository,
|
groupRepo GroupRepository,
|
||||||
cache GatewayCache,
|
cache GatewayCache,
|
||||||
|
schedulerSnapshot *SchedulerSnapshotService,
|
||||||
tokenProvider *GeminiTokenProvider,
|
tokenProvider *GeminiTokenProvider,
|
||||||
rateLimitService *RateLimitService,
|
rateLimitService *RateLimitService,
|
||||||
httpUpstream HTTPUpstream,
|
httpUpstream HTTPUpstream,
|
||||||
@@ -61,6 +63,7 @@ func NewGeminiMessagesCompatService(
|
|||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
|
schedulerSnapshot: schedulerSnapshot,
|
||||||
tokenProvider: tokenProvider,
|
tokenProvider: tokenProvider,
|
||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
@@ -105,12 +108,6 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||||
// 注意:强制平台模式不走混合调度
|
// 注意:强制平台模式不走混合调度
|
||||||
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||||||
var queryPlatforms []string
|
|
||||||
if useMixedScheduling {
|
|
||||||
queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
|
|
||||||
} else {
|
|
||||||
queryPlatforms = []string{platform}
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheKey := "gemini:" + sessionHash
|
cacheKey := "gemini:" + sessionHash
|
||||||
|
|
||||||
@@ -118,7 +115,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||||||
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
valid := false
|
valid := false
|
||||||
@@ -149,22 +146,16 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||||||
var accounts []Account
|
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
|
||||||
var err error
|
if err != nil {
|
||||||
if groupID != nil {
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
}
|
||||||
|
// 强制平台模式下,分组中找不到账户时回退查询全部
|
||||||
|
if len(accounts) == 0 && groupID != nil && hasForcePlatform {
|
||||||
|
accounts, err = s.listSchedulableAccountsOnce(ctx, nil, platform, hasForcePlatform)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
}
|
}
|
||||||
// 强制平台模式下,分组中找不到账户时回退查询全部
|
|
||||||
if len(accounts) == 0 && hasForcePlatform {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var selected *Account
|
var selected *Account
|
||||||
@@ -245,6 +236,31 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
|
|||||||
return s.antigravityGatewayService
|
return s.antigravityGatewayService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||||
|
if s.schedulerSnapshot != nil {
|
||||||
|
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
return s.accountRepo.GetByID(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) {
|
||||||
|
if s.schedulerSnapshot != nil {
|
||||||
|
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||||
|
return accounts, err
|
||||||
|
}
|
||||||
|
|
||||||
|
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||||||
|
queryPlatforms := []string{platform}
|
||||||
|
if useMixedScheduling {
|
||||||
|
queryPlatforms = []string{platform, PlatformAntigravity}
|
||||||
|
}
|
||||||
|
|
||||||
|
if groupID != nil {
|
||||||
|
return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
||||||
|
}
|
||||||
|
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||||
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||||
@@ -266,13 +282,7 @@ func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (strin
|
|||||||
|
|
||||||
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
|
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
|
||||||
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
|
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
|
||||||
var accounts []Account
|
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformAntigravity, false)
|
||||||
var err error
|
|
||||||
if groupID != nil {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
|
|
||||||
} else {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -288,13 +298,7 @@ func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context
|
|||||||
// 3) OAuth accounts explicitly marked as ai_studio
|
// 3) OAuth accounts explicitly marked as ai_studio
|
||||||
// 4) Any remaining Gemini accounts (fallback)
|
// 4) Any remaining Gemini accounts (fallback)
|
||||||
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
|
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
|
||||||
var accounts []Account
|
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformGemini, true)
|
||||||
var err error
|
|
||||||
if groupID != nil {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
|
|
||||||
} else {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -543,12 +547,21 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
|
|
||||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: safeErr,
|
||||||
|
})
|
||||||
if attempt < geminiMaxRetries {
|
if attempt < geminiMaxRetries {
|
||||||
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
||||||
sleepGeminiBackoff(attempt)
|
sleepGeminiBackoff(attempt)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+safeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
|
// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
|
||||||
@@ -558,6 +571,30 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
if isGeminiSignatureRelatedError(respBody) {
|
if isGeminiSignatureRelatedError(respBody) {
|
||||||
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
|
if upstreamReqID == "" {
|
||||||
|
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||||
|
}
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: upstreamReqID,
|
||||||
|
Kind: "signature_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
var strippedClaudeBody []byte
|
var strippedClaudeBody []byte
|
||||||
stageName := ""
|
stageName := ""
|
||||||
switch signatureRetryStage {
|
switch signatureRetryStage {
|
||||||
@@ -608,6 +645,30 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
}
|
}
|
||||||
if attempt < geminiMaxRetries {
|
if attempt < geminiMaxRetries {
|
||||||
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
|
if upstreamReqID == "" {
|
||||||
|
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||||
|
}
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: upstreamReqID,
|
||||||
|
Kind: "retry",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
|
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
|
||||||
sleepGeminiBackoff(attempt)
|
sleepGeminiBackoff(attempt)
|
||||||
continue
|
continue
|
||||||
@@ -633,12 +694,62 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
}
|
}
|
||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
if tempMatched {
|
if tempMatched {
|
||||||
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
|
if upstreamReqID == "" {
|
||||||
|
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||||
|
}
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: upstreamReqID,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
}
|
}
|
||||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||||
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
|
if upstreamReqID == "" {
|
||||||
|
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||||
|
}
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: upstreamReqID,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
}
|
}
|
||||||
return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody)
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
|
if upstreamReqID == "" {
|
||||||
|
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||||
|
}
|
||||||
|
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
requestID := resp.Header.Get(requestIDHeader)
|
requestID := resp.Header.Get(requestIDHeader)
|
||||||
@@ -863,6 +974,14 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
|
|
||||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: safeErr,
|
||||||
|
})
|
||||||
if attempt < geminiMaxRetries {
|
if attempt < geminiMaxRetries {
|
||||||
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
||||||
sleepGeminiBackoff(attempt)
|
sleepGeminiBackoff(attempt)
|
||||||
@@ -880,7 +999,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
FirstTokenMs: nil,
|
FirstTokenMs: nil,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
|
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
||||||
@@ -899,6 +1019,30 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
}
|
}
|
||||||
if attempt < geminiMaxRetries {
|
if attempt < geminiMaxRetries {
|
||||||
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
|
if upstreamReqID == "" {
|
||||||
|
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||||
|
}
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: upstreamReqID,
|
||||||
|
Kind: "retry",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
|
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
|
||||||
sleepGeminiBackoff(attempt)
|
sleepGeminiBackoff(attempt)
|
||||||
continue
|
continue
|
||||||
@@ -962,19 +1106,84 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
}
|
}
|
||||||
|
|
||||||
if tempMatched {
|
if tempMatched {
|
||||||
|
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(evBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: requestID,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
}
|
}
|
||||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||||
|
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(evBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: requestID,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
}
|
}
|
||||||
|
|
||||||
respBody = unwrapIfNeeded(isOAuth, respBody)
|
respBody = unwrapIfNeeded(isOAuth, respBody)
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
log.Printf("[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
|
||||||
|
}
|
||||||
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: requestID,
|
||||||
|
Kind: "http_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
if contentType == "" {
|
if contentType == "" {
|
||||||
contentType = "application/json"
|
contentType = "application/json"
|
||||||
}
|
}
|
||||||
c.Data(resp.StatusCode, contentType, respBody)
|
c.Data(resp.StatusCode, contentType, respBody)
|
||||||
return nil, fmt.Errorf("gemini upstream error: %d", resp.StatusCode)
|
if upstreamMsg == "" {
|
||||||
|
return nil, fmt.Errorf("gemini upstream error: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("gemini upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
var usage *ClaudeUsage
|
var usage *ClaudeUsage
|
||||||
@@ -1076,7 +1285,32 @@ func sanitizeUpstreamErrorMessage(msg string) string {
|
|||||||
return sensitiveQueryParamRegex.ReplaceAllString(msg, `$1***`)
|
return sensitiveQueryParamRegex.ReplaceAllString(msg, `$1***`)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, upstreamStatus int, body []byte) error {
|
func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(body), maxBytes)
|
||||||
|
}
|
||||||
|
setOpsUpstreamError(c, upstreamStatus, upstreamMsg, upstreamDetail)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: upstreamStatus,
|
||||||
|
UpstreamRequestID: upstreamRequestID,
|
||||||
|
Kind: "http_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
|
||||||
|
}
|
||||||
|
|
||||||
var statusCode int
|
var statusCode int
|
||||||
var errType, errMsg string
|
var errType, errMsg string
|
||||||
|
|
||||||
@@ -1184,7 +1418,10 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, ups
|
|||||||
"type": "error",
|
"type": "error",
|
||||||
"error": gin.H{"type": errType, "message": errMsg},
|
"error": gin.H{"type": errType, "message": errMsg},
|
||||||
})
|
})
|
||||||
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
if upstreamMsg == "" {
|
||||||
|
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
type claudeErrorMapping struct {
|
type claudeErrorMapping struct {
|
||||||
|
|||||||
@@ -50,13 +50,15 @@ type UpdateGroupRequest struct {
|
|||||||
|
|
||||||
// GroupService 分组管理服务
|
// GroupService 分组管理服务
|
||||||
type GroupService struct {
|
type GroupService struct {
|
||||||
groupRepo GroupRepository
|
groupRepo GroupRepository
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGroupService 创建分组服务实例
|
// NewGroupService 创建分组服务实例
|
||||||
func NewGroupService(groupRepo GroupRepository) *GroupService {
|
func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService {
|
||||||
return &GroupService{
|
return &GroupService{
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
|
|||||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||||
return nil, fmt.Errorf("update group: %w", err)
|
return nil, fmt.Errorf("update group: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
@@ -167,6 +172,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
|
|||||||
return fmt.Errorf("get group: %w", err)
|
return fmt.Errorf("get group: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||||
|
}
|
||||||
if err := s.groupRepo.Delete(ctx, id); err != nil {
|
if err := s.groupRepo.Delete(ctx, id); err != nil {
|
||||||
return fmt.Errorf("delete group: %w", err)
|
return fmt.Errorf("delete group: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
528
backend/internal/service/openai_codex_transform.go
Normal file
528
backend/internal/service/openai_codex_transform.go
Normal file
@@ -0,0 +1,528 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
|
||||||
|
codexCacheTTL = 15 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed prompts/codex_cli_instructions.md
|
||||||
|
var codexCLIInstructions string
|
||||||
|
|
||||||
|
var codexModelMap = map[string]string{
|
||||||
|
"gpt-5.1-codex": "gpt-5.1-codex",
|
||||||
|
"gpt-5.1-codex-low": "gpt-5.1-codex",
|
||||||
|
"gpt-5.1-codex-medium": "gpt-5.1-codex",
|
||||||
|
"gpt-5.1-codex-high": "gpt-5.1-codex",
|
||||||
|
"gpt-5.1-codex-max": "gpt-5.1-codex-max",
|
||||||
|
"gpt-5.1-codex-max-low": "gpt-5.1-codex-max",
|
||||||
|
"gpt-5.1-codex-max-medium": "gpt-5.1-codex-max",
|
||||||
|
"gpt-5.1-codex-max-high": "gpt-5.1-codex-max",
|
||||||
|
"gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max",
|
||||||
|
"gpt-5.2": "gpt-5.2",
|
||||||
|
"gpt-5.2-none": "gpt-5.2",
|
||||||
|
"gpt-5.2-low": "gpt-5.2",
|
||||||
|
"gpt-5.2-medium": "gpt-5.2",
|
||||||
|
"gpt-5.2-high": "gpt-5.2",
|
||||||
|
"gpt-5.2-xhigh": "gpt-5.2",
|
||||||
|
"gpt-5.2-codex": "gpt-5.2-codex",
|
||||||
|
"gpt-5.2-codex-low": "gpt-5.2-codex",
|
||||||
|
"gpt-5.2-codex-medium": "gpt-5.2-codex",
|
||||||
|
"gpt-5.2-codex-high": "gpt-5.2-codex",
|
||||||
|
"gpt-5.2-codex-xhigh": "gpt-5.2-codex",
|
||||||
|
"gpt-5.1-codex-mini": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5.1": "gpt-5.1",
|
||||||
|
"gpt-5.1-none": "gpt-5.1",
|
||||||
|
"gpt-5.1-low": "gpt-5.1",
|
||||||
|
"gpt-5.1-medium": "gpt-5.1",
|
||||||
|
"gpt-5.1-high": "gpt-5.1",
|
||||||
|
"gpt-5.1-chat-latest": "gpt-5.1",
|
||||||
|
"gpt-5-codex": "gpt-5.1-codex",
|
||||||
|
"codex-mini-latest": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5-codex-mini": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5-codex-mini-high": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5": "gpt-5.1",
|
||||||
|
"gpt-5-mini": "gpt-5.1",
|
||||||
|
"gpt-5-nano": "gpt-5.1",
|
||||||
|
}
|
||||||
|
|
||||||
|
type codexTransformResult struct {
|
||||||
|
Modified bool
|
||||||
|
NormalizedModel string
|
||||||
|
PromptCacheKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type opencodeCacheMetadata struct {
|
||||||
|
ETag string `json:"etag"`
|
||||||
|
LastFetch string `json:"lastFetch,omitempty"`
|
||||||
|
LastChecked int64 `json:"lastChecked"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
||||||
|
result := codexTransformResult{}
|
||||||
|
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||||
|
needsToolContinuation := NeedsToolContinuation(reqBody)
|
||||||
|
|
||||||
|
model := ""
|
||||||
|
if v, ok := reqBody["model"].(string); ok {
|
||||||
|
model = v
|
||||||
|
}
|
||||||
|
normalizedModel := normalizeCodexModel(model)
|
||||||
|
if normalizedModel != "" {
|
||||||
|
if model != normalizedModel {
|
||||||
|
reqBody["model"] = normalizedModel
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
result.NormalizedModel = normalizedModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth 走 ChatGPT internal API 时,store 必须为 false;显式 true 也会强制覆盖。
|
||||||
|
// 避免上游返回 "Store must be set to false"。
|
||||||
|
if v, ok := reqBody["store"].(bool); !ok || v {
|
||||||
|
reqBody["store"] = false
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
if v, ok := reqBody["stream"].(bool); !ok || !v {
|
||||||
|
reqBody["stream"] = true
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := reqBody["max_output_tokens"]; ok {
|
||||||
|
delete(reqBody, "max_output_tokens")
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
if _, ok := reqBody["max_completion_tokens"]; ok {
|
||||||
|
delete(reqBody, "max_completion_tokens")
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if normalizeCodexTools(reqBody) {
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||||
|
result.PromptCacheKey = strings.TrimSpace(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||||
|
existingInstructions, _ := reqBody["instructions"].(string)
|
||||||
|
existingInstructions = strings.TrimSpace(existingInstructions)
|
||||||
|
|
||||||
|
if instructions != "" {
|
||||||
|
if existingInstructions != instructions {
|
||||||
|
reqBody["instructions"] = instructions
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
} else if existingInstructions == "" {
|
||||||
|
// 未获取到 opencode 指令时,回退使用 Codex CLI 指令。
|
||||||
|
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||||
|
if codexInstructions != "" {
|
||||||
|
reqBody["instructions"] = codexInstructions
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
|
||||||
|
if input, ok := reqBody["input"].([]any); ok {
|
||||||
|
input = filterCodexInput(input, needsToolContinuation)
|
||||||
|
reqBody["input"] = input
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeCodexModel(model string) string {
|
||||||
|
if model == "" {
|
||||||
|
return "gpt-5.1"
|
||||||
|
}
|
||||||
|
|
||||||
|
modelID := model
|
||||||
|
if strings.Contains(modelID, "/") {
|
||||||
|
parts := strings.Split(modelID, "/")
|
||||||
|
modelID = parts[len(parts)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if mapped := getNormalizedCodexModel(modelID); mapped != "" {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := strings.ToLower(modelID)
|
||||||
|
|
||||||
|
if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
|
||||||
|
return "gpt-5.2-codex"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
|
||||||
|
return "gpt-5.2"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
|
||||||
|
return "gpt-5.1-codex-max"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") {
|
||||||
|
return "gpt-5.1-codex-mini"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "codex-mini-latest") ||
|
||||||
|
strings.Contains(normalized, "gpt-5-codex-mini") ||
|
||||||
|
strings.Contains(normalized, "gpt 5 codex mini") {
|
||||||
|
return "codex-mini-latest"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") {
|
||||||
|
return "gpt-5.1-codex"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") {
|
||||||
|
return "gpt-5.1"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "codex") {
|
||||||
|
return "gpt-5.1-codex"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
|
||||||
|
return "gpt-5.1"
|
||||||
|
}
|
||||||
|
|
||||||
|
return "gpt-5.1"
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNormalizedCodexModel(modelID string) string {
|
||||||
|
if modelID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if mapped, ok := codexModelMap[modelID]; ok {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(modelID)
|
||||||
|
for key, value := range codexModelMap {
|
||||||
|
if strings.ToLower(key) == lower {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
|
||||||
|
cacheDir := codexCachePath("")
|
||||||
|
if cacheDir == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
cacheFile := filepath.Join(cacheDir, cacheFileName)
|
||||||
|
metaFile := filepath.Join(cacheDir, metaFileName)
|
||||||
|
|
||||||
|
var cachedContent string
|
||||||
|
if content, ok := readFile(cacheFile); ok {
|
||||||
|
cachedContent = content
|
||||||
|
}
|
||||||
|
|
||||||
|
var meta opencodeCacheMetadata
|
||||||
|
if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
|
||||||
|
if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
|
||||||
|
return cachedContent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
content, etag, status, err := fetchWithETag(url, meta.ETag)
|
||||||
|
if err == nil && status == http.StatusNotModified && cachedContent != "" {
|
||||||
|
return cachedContent
|
||||||
|
}
|
||||||
|
if err == nil && status >= 200 && status < 300 && content != "" {
|
||||||
|
_ = writeFile(cacheFile, content)
|
||||||
|
meta = opencodeCacheMetadata{
|
||||||
|
ETag: etag,
|
||||||
|
LastFetch: time.Now().UTC().Format(time.RFC3339),
|
||||||
|
LastChecked: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
_ = writeJSON(metaFile, meta)
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
return cachedContent
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOpenCodeCodexHeader() string {
|
||||||
|
// 优先从 opencode 仓库缓存获取指令。
|
||||||
|
opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
|
||||||
|
|
||||||
|
// 若 opencode 指令可用,直接返回。
|
||||||
|
if opencodeInstructions != "" {
|
||||||
|
return opencodeInstructions
|
||||||
|
}
|
||||||
|
|
||||||
|
// 否则回退使用本地 Codex CLI 指令。
|
||||||
|
return getCodexCLIInstructions()
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCodexCLIInstructions() string {
|
||||||
|
return codexCLIInstructions
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetOpenCodeInstructions() string {
|
||||||
|
return getOpenCodeCodexHeader()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexCLIInstructions 返回内置的 Codex CLI 指令内容。
|
||||||
|
func GetCodexCLIInstructions() string {
|
||||||
|
return getCodexCLIInstructions()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
|
||||||
|
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
|
||||||
|
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||||
|
if codexInstructions == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
existingInstructions, _ := reqBody["instructions"].(string)
|
||||||
|
if strings.TrimSpace(existingInstructions) != codexInstructions {
|
||||||
|
reqBody["instructions"] = codexInstructions
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInstructionError reports whether errorMessage appears related to the
// instruction format or system prompt, using a case-insensitive keyword scan.
func IsInstructionError(errorMessage string) bool {
	if errorMessage == "" {
		return false
	}

	lowerMsg := strings.ToLower(errorMessage)
	// "instruction" also matches "instructions", so the plural form previously
	// listed separately was redundant and has been dropped.
	instructionKeywords := []string{
		"instruction",
		"system prompt",
		"system message",
		"invalid prompt",
		"prompt format",
	}

	for _, keyword := range instructionKeywords {
		if strings.Contains(lowerMsg, keyword) {
			return true
		}
	}

	return false
}
|
||||||
|
|
||||||
|
// filterCodexInput filters item_reference entries and "id"/"call_id" fields
// as needed.
// When preserveReferences is true, references and ids are kept so that
// continuation requests retain the context they depend on.
func filterCodexInput(input []any, preserveReferences bool) []any {
	filtered := make([]any, 0, len(input))
	for _, item := range input {
		m, ok := item.(map[string]any)
		if !ok {
			// Non-map items are passed through untouched.
			filtered = append(filtered, item)
			continue
		}
		typ, _ := m["type"].(string)
		if typ == "item_reference" {
			if !preserveReferences {
				// References are dropped entirely outside continuation mode.
				continue
			}
			// Preserved references are shallow-copied before being appended.
			newItem := make(map[string]any, len(m))
			for key, value := range m {
				newItem[key] = value
			}
			filtered = append(filtered, newItem)
			continue
		}

		newItem := m
		copied := false
		// Only create a copy when a field actually needs to change, to avoid
		// mutating the caller's original input.
		ensureCopy := func() {
			if copied {
				return
			}
			newItem = make(map[string]any, len(m))
			for key, value := range m {
				newItem[key] = value
			}
			copied = true
		}

		// Tool-call items missing a call_id get it backfilled from "id".
		if isCodexToolCallItemType(typ) {
			if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" {
				if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
					ensureCopy()
					newItem["call_id"] = id
				}
			}
		}

		// Outside continuation mode, strip "id" everywhere and "call_id" from
		// non-tool-call items (tool calls keep call_id for pairing).
		if !preserveReferences {
			ensureCopy()
			delete(newItem, "id")
			if !isCodexToolCallItemType(typ) {
				delete(newItem, "call_id")
			}
		}

		filtered = append(filtered, newItem)
	}
	return filtered
}
|
||||||
|
|
||||||
|
// isCodexToolCallItemType reports whether typ names a tool-call input item
// ("*_call") or a tool-call result ("*_call_output").
func isCodexToolCallItemType(typ string) bool {
	switch {
	case typ == "":
		return false
	case strings.HasSuffix(typ, "_call_output"):
		return true
	default:
		return strings.HasSuffix(typ, "_call")
	}
}
|
||||||
|
|
||||||
|
// normalizeCodexTools flattens OpenAI-style function tool definitions by
// promoting "name", "description", "parameters" and "strict" from the nested
// "function" object onto the tool itself, without overwriting fields already
// present at the top level. It reports whether anything changed.
//
// Tool maps are mutated in place, so the original write-backs
// (tools[idx] = toolMap and reqBody["tools"] = tools) were no-ops and have
// been removed.
func normalizeCodexTools(reqBody map[string]any) bool {
	rawTools, ok := reqBody["tools"]
	if !ok || rawTools == nil {
		return false
	}
	tools, ok := rawTools.([]any)
	if !ok {
		return false
	}

	modified := false
	for _, tool := range tools {
		toolMap, ok := tool.(map[string]any)
		if !ok {
			continue
		}

		toolType, _ := toolMap["type"].(string)
		if strings.TrimSpace(toolType) != "function" {
			continue
		}

		function, ok := toolMap["function"].(map[string]any)
		if !ok {
			continue
		}

		// String fields are promoted only when present and non-blank.
		for _, key := range []string{"name", "description"} {
			if _, exists := toolMap[key]; exists {
				continue
			}
			if v, ok := function[key].(string); ok && strings.TrimSpace(v) != "" {
				toolMap[key] = v
				modified = true
			}
		}
		// "parameters" and "strict" are copied verbatim whenever present.
		for _, key := range []string{"parameters", "strict"} {
			if _, exists := toolMap[key]; exists {
				continue
			}
			if v, ok := function[key]; ok {
				toolMap[key] = v
				modified = true
			}
		}
	}

	return modified
}
|
||||||
|
|
||||||
|
func codexCachePath(filename string) string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
cacheDir := filepath.Join(home, ".opencode", "cache")
|
||||||
|
if filename == "" {
|
||||||
|
return cacheDir
|
||||||
|
}
|
||||||
|
return filepath.Join(cacheDir, filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readFile(path string) (string, bool) {
|
||||||
|
if path == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return string(data), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeFile(path, content string) error {
|
||||||
|
if path == "" {
|
||||||
|
return fmt.Errorf("empty cache path")
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(path, []byte(content), 0o644)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadJSON(path string, target any) bool {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, target); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeJSON(path string, value any) error {
|
||||||
|
if path == "" {
|
||||||
|
return fmt.Errorf("empty json path")
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(path, data, 0o644)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchWithETag(url, etag string) (string, string, int, error) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", 0, err
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", "sub2api-codex")
|
||||||
|
if etag != "" {
|
||||||
|
req.Header.Set("If-None-Match", etag)
|
||||||
|
}
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", resp.StatusCode, err
|
||||||
|
}
|
||||||
|
return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
|
||||||
|
}
|
||||||
167
backend/internal/service/openai_codex_transform_test.go
Normal file
167
backend/internal/service/openai_codex_transform_test.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Continuation scenario: item_reference entries and ids are preserved, and
// store=true is no longer forced.
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
	setupCodexCache(t)

	reqBody := map[string]any{
		"model": "gpt-5.2",
		"input": []any{
			map[string]any{"type": "item_reference", "id": "ref1", "text": "x"},
			map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok", "id": "o1"},
		},
		"tool_choice": "auto",
	}

	applyCodexOAuthTransform(reqBody)

	// store defaults to false when not explicitly set.
	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)

	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 2)

	// Assert input[0] is a map so a type mismatch fails cleanly instead of
	// panicking mid-test.
	first, ok := input[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "item_reference", first["type"])
	require.Equal(t, "ref1", first["id"])

	// Assert input[1] is a map so the subsequent field assertion is safe.
	second, ok := input[1].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "o1", second["id"])
}
|
||||||
|
|
||||||
|
// Continuation scenario: an explicit store=false is no longer forced to true
// and stays false.
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
	setupCodexCache(t)

	reqBody := map[string]any{
		"model": "gpt-5.1",
		"store": false,
		"input": []any{
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
		},
		"tool_choice": "auto",
	}

	applyCodexOAuthTransform(reqBody)

	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
}
|
||||||
|
|
||||||
|
// An explicit store=true is also forced to false by the transform.
func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
	setupCodexCache(t)

	reqBody := map[string]any{
		"model": "gpt-5.1",
		"store": true,
		"input": []any{
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
		},
		"tool_choice": "auto",
	}

	applyCodexOAuthTransform(reqBody)

	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
}
|
||||||
|
|
||||||
|
// Non-continuation scenario: store defaults to false when unset, and "id"
// fields are stripped from input items.
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
	setupCodexCache(t)

	reqBody := map[string]any{
		"model": "gpt-5.1",
		"input": []any{
			map[string]any{"type": "text", "id": "t1", "text": "hi"},
		},
	}

	applyCodexOAuthTransform(reqBody)

	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)

	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 1)
	// Assert input[0] is a map so a type mismatch fails with a clear message.
	item, ok := input[0].(map[string]any)
	require.True(t, ok)
	_, hasID := item["id"]
	require.False(t, hasID)
}
|
||||||
|
|
||||||
|
// When references are not preserved, item_reference entries are dropped and
// "id" fields are removed from the remaining items.
func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
	input := []any{
		map[string]any{"type": "item_reference", "id": "ref1"},
		map[string]any{"type": "text", "id": "t1", "text": "hi"},
	}

	filtered := filterCodexInput(input, false)
	require.Len(t, filtered, 1)
	// Assert filtered[0] is a map so the field checks are reliable.
	item, ok := filtered[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "text", item["type"])
	_, hasID := item["id"]
	require.False(t, hasID)
}
|
||||||
|
|
||||||
|
// An empty input slice must remain empty and not cause a panic.
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
	setupCodexCache(t)

	reqBody := map[string]any{
		"model": "gpt-5.1",
		"input": []any{},
	}

	applyCodexOAuthTransform(reqBody)

	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 0)
}
|
||||||
|
|
||||||
|
// setupCodexCache points HOME at a temp directory pre-populated with a fresh
// opencode header cache so tests never reach the network.
func setupCodexCache(t *testing.T) {
	t.Helper()

	// Use a temporary HOME to avoid triggering a network fetch of the header.
	tempDir := t.TempDir()
	t.Setenv("HOME", tempDir)

	cacheDir := filepath.Join(tempDir, ".opencode", "cache")
	require.NoError(t, os.MkdirAll(cacheDir, 0o755))
	require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))

	// lastChecked is "now" so the cached header is considered fresh.
	meta := map[string]any{
		"etag":        "",
		"lastFetch":   time.Now().UTC().Format(time.RFC3339),
		"lastChecked": time.Now().UnixMilli(),
	}
	data, err := json.Marshal(meta)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
}
|
||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -41,6 +42,7 @@ var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
|||||||
var openaiAllowedHeaders = map[string]bool{
|
var openaiAllowedHeaders = map[string]bool{
|
||||||
"accept-language": true,
|
"accept-language": true,
|
||||||
"content-type": true,
|
"content-type": true,
|
||||||
|
"conversation_id": true,
|
||||||
"user-agent": true,
|
"user-agent": true,
|
||||||
"originator": true,
|
"originator": true,
|
||||||
"session_id": true,
|
"session_id": true,
|
||||||
@@ -84,6 +86,7 @@ type OpenAIGatewayService struct {
|
|||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
cache GatewayCache
|
cache GatewayCache
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
schedulerSnapshot *SchedulerSnapshotService
|
||||||
concurrencyService *ConcurrencyService
|
concurrencyService *ConcurrencyService
|
||||||
billingService *BillingService
|
billingService *BillingService
|
||||||
rateLimitService *RateLimitService
|
rateLimitService *RateLimitService
|
||||||
@@ -100,6 +103,7 @@ func NewOpenAIGatewayService(
|
|||||||
userSubRepo UserSubscriptionRepository,
|
userSubRepo UserSubscriptionRepository,
|
||||||
cache GatewayCache,
|
cache GatewayCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
|
schedulerSnapshot *SchedulerSnapshotService,
|
||||||
concurrencyService *ConcurrencyService,
|
concurrencyService *ConcurrencyService,
|
||||||
billingService *BillingService,
|
billingService *BillingService,
|
||||||
rateLimitService *RateLimitService,
|
rateLimitService *RateLimitService,
|
||||||
@@ -114,6 +118,7 @@ func NewOpenAIGatewayService(
|
|||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
schedulerSnapshot: schedulerSnapshot,
|
||||||
concurrencyService: concurrencyService,
|
concurrencyService: concurrencyService,
|
||||||
billingService: billingService,
|
billingService: billingService,
|
||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
@@ -158,7 +163,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
|||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
// Refresh sticky session TTL
|
// Refresh sticky session TTL
|
||||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||||
@@ -169,16 +174,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. Get schedulable OpenAI accounts
|
// 2. Get schedulable OpenAI accounts
|
||||||
var accounts []Account
|
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
||||||
var err error
|
|
||||||
// 简易模式:忽略分组限制,查询所有可用账号
|
|
||||||
if s.cfg.RunMode == config.RunModeSimple {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
|
||||||
} else if groupID != nil {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
|
||||||
} else {
|
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -300,7 +296,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
||||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
@@ -445,6 +441,10 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
|
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
|
||||||
|
if s.schedulerSnapshot != nil {
|
||||||
|
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false)
|
||||||
|
return accounts, err
|
||||||
|
}
|
||||||
var accounts []Account
|
var accounts []Account
|
||||||
var err error
|
var err error
|
||||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||||
@@ -467,6 +467,13 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
|
|||||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||||
|
if s.schedulerSnapshot != nil {
|
||||||
|
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
return s.accountRepo.GetByID(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||||
if s.cfg != nil {
|
if s.cfg != nil {
|
||||||
return s.cfg.Gateway.Scheduling
|
return s.cfg.Gateway.Scheduling
|
||||||
@@ -511,7 +518,7 @@ func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -528,33 +535,97 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
// Extract model and stream from parsed body
|
// Extract model and stream from parsed body
|
||||||
reqModel, _ := reqBody["model"].(string)
|
reqModel, _ := reqBody["model"].(string)
|
||||||
reqStream, _ := reqBody["stream"].(bool)
|
reqStream, _ := reqBody["stream"].(bool)
|
||||||
|
promptCacheKey := ""
|
||||||
|
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||||
|
promptCacheKey = strings.TrimSpace(v)
|
||||||
|
}
|
||||||
|
|
||||||
// Track if body needs re-serialization
|
// Track if body needs re-serialization
|
||||||
bodyModified := false
|
bodyModified := false
|
||||||
originalModel := reqModel
|
originalModel := reqModel
|
||||||
|
|
||||||
// Apply model mapping
|
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
||||||
|
|
||||||
|
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||||
mappedModel := account.GetMappedModel(reqModel)
|
mappedModel := account.GetMappedModel(reqModel)
|
||||||
if mappedModel != reqModel {
|
if mappedModel != reqModel {
|
||||||
|
log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
|
||||||
reqBody["model"] = mappedModel
|
reqBody["model"] = mappedModel
|
||||||
bodyModified = true
|
bodyModified = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// For OAuth accounts using ChatGPT internal API:
|
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
|
||||||
// 1. Add store: false
|
if model, ok := reqBody["model"].(string); ok {
|
||||||
// 2. Normalize input format for Codex API compatibility
|
normalizedModel := normalizeCodexModel(model)
|
||||||
if account.Type == AccountTypeOAuth {
|
if normalizedModel != "" && normalizedModel != model {
|
||||||
reqBody["store"] = false
|
log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
|
||||||
bodyModified = true
|
model, normalizedModel, account.Name, account.Type, isCodexCLI)
|
||||||
|
reqBody["model"] = normalizedModel
|
||||||
// Normalize input format: convert AI SDK multi-part content format to simplified format
|
mappedModel = normalizedModel
|
||||||
// AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]}
|
|
||||||
// Codex API expects: {"content": "..."}
|
|
||||||
if normalizeInputForCodexAPI(reqBody) {
|
|
||||||
bodyModified = true
|
bodyModified = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
|
||||||
|
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
|
||||||
|
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
|
||||||
|
reasoning["effort"] = "none"
|
||||||
|
bodyModified = true
|
||||||
|
log.Printf("[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if account.Type == AccountTypeOAuth && !isCodexCLI {
|
||||||
|
codexResult := applyCodexOAuthTransform(reqBody)
|
||||||
|
if codexResult.Modified {
|
||||||
|
bodyModified = true
|
||||||
|
}
|
||||||
|
if codexResult.NormalizedModel != "" {
|
||||||
|
mappedModel = codexResult.NormalizedModel
|
||||||
|
}
|
||||||
|
if codexResult.PromptCacheKey != "" {
|
||||||
|
promptCacheKey = codexResult.PromptCacheKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle max_output_tokens based on platform and account type
|
||||||
|
if !isCodexCLI {
|
||||||
|
if maxOutputTokens, hasMaxOutputTokens := reqBody["max_output_tokens"]; hasMaxOutputTokens {
|
||||||
|
switch account.Platform {
|
||||||
|
case PlatformOpenAI:
|
||||||
|
// For OpenAI API Key, remove max_output_tokens (not supported)
|
||||||
|
// For OpenAI OAuth (Responses API), keep it (supported)
|
||||||
|
if account.Type == AccountTypeAPIKey {
|
||||||
|
delete(reqBody, "max_output_tokens")
|
||||||
|
bodyModified = true
|
||||||
|
}
|
||||||
|
case PlatformAnthropic:
|
||||||
|
// For Anthropic (Claude), convert to max_tokens
|
||||||
|
delete(reqBody, "max_output_tokens")
|
||||||
|
if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens {
|
||||||
|
reqBody["max_tokens"] = maxOutputTokens
|
||||||
|
}
|
||||||
|
bodyModified = true
|
||||||
|
case PlatformGemini:
|
||||||
|
// For Gemini, remove (will be handled by Gemini-specific transform)
|
||||||
|
delete(reqBody, "max_output_tokens")
|
||||||
|
bodyModified = true
|
||||||
|
default:
|
||||||
|
// For unknown platforms, remove to be safe
|
||||||
|
delete(reqBody, "max_output_tokens")
|
||||||
|
bodyModified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also handle max_completion_tokens (similar logic)
|
||||||
|
if _, hasMaxCompletionTokens := reqBody["max_completion_tokens"]; hasMaxCompletionTokens {
|
||||||
|
if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI {
|
||||||
|
delete(reqBody, "max_completion_tokens")
|
||||||
|
bodyModified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Re-serialize body only if modified
|
// Re-serialize body only if modified
|
||||||
if bodyModified {
|
if bodyModified {
|
||||||
var err error
|
var err error
|
||||||
@@ -571,7 +642,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build upstream request
|
// Build upstream request
|
||||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream)
|
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -585,13 +656,53 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
// Send request
|
// Send request
|
||||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: safeErr,
|
||||||
|
})
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream request failed",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
// Handle error response
|
// Handle error response
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
s.handleFailoverSideEffects(ctx, resp, account)
|
s.handleFailoverSideEffects(ctx, resp, account)
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
}
|
}
|
||||||
@@ -632,7 +743,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
|
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
|
||||||
// Determine target URL based on account type
|
// Determine target URL based on account type
|
||||||
var targetURL string
|
var targetURL string
|
||||||
switch account.Type {
|
switch account.Type {
|
||||||
@@ -672,12 +783,6 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
|||||||
if chatgptAccountID != "" {
|
if chatgptAccountID != "" {
|
||||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||||
}
|
}
|
||||||
// Set accept header based on stream mode
|
|
||||||
if isStream {
|
|
||||||
req.Header.Set("accept", "text/event-stream")
|
|
||||||
} else {
|
|
||||||
req.Header.Set("accept", "application/json")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Whitelist passthrough headers
|
// Whitelist passthrough headers
|
||||||
@@ -689,6 +794,19 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if account.Type == AccountTypeOAuth {
|
||||||
|
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||||
|
if isCodexCLI {
|
||||||
|
req.Header.Set("originator", "codex_cli_rs")
|
||||||
|
} else {
|
||||||
|
req.Header.Set("originator", "opencode")
|
||||||
|
}
|
||||||
|
req.Header.Set("accept", "text/event-stream")
|
||||||
|
if promptCacheKey != "" {
|
||||||
|
req.Header.Set("conversation_id", promptCacheKey)
|
||||||
|
req.Header.Set("session_id", promptCacheKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Apply custom User-Agent if configured
|
// Apply custom User-Agent if configured
|
||||||
customUA := account.GetOpenAIUserAgent()
|
customUA := account.GetOpenAIUserAgent()
|
||||||
@@ -705,17 +823,52 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
|
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(body), maxBytes)
|
||||||
|
}
|
||||||
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||||
|
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
log.Printf(
|
||||||
|
"OpenAI upstream error %d (account=%d platform=%s type=%s): %s",
|
||||||
|
resp.StatusCode,
|
||||||
|
account.ID,
|
||||||
|
account.Platform,
|
||||||
|
account.Type,
|
||||||
|
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// Check custom error codes
|
// Check custom error codes
|
||||||
if !account.ShouldHandleErrorCode(resp.StatusCode) {
|
if !account.ShouldHandleErrorCode(resp.StatusCode) {
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "http_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"type": "upstream_error",
|
"type": "upstream_error",
|
||||||
"message": "Upstream gateway error",
|
"message": "Upstream gateway error",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
|
if upstreamMsg == "" {
|
||||||
|
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle upstream error (mark account status)
|
// Handle upstream error (mark account status)
|
||||||
@@ -723,6 +876,19 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
|||||||
if s.rateLimitService != nil {
|
if s.rateLimitService != nil {
|
||||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||||
}
|
}
|
||||||
|
kind := "http_error"
|
||||||
|
if shouldDisable {
|
||||||
|
kind = "failover"
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: kind,
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
if shouldDisable {
|
if shouldDisable {
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
}
|
}
|
||||||
@@ -761,7 +927,10 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
if upstreamMsg == "" {
|
||||||
|
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// openaiStreamingResult streaming response result
|
// openaiStreamingResult streaming response result
|
||||||
@@ -933,6 +1102,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||||
|
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||||
|
}
|
||||||
sendErrorEvent("stream_timeout")
|
sendErrorEvent("stream_timeout")
|
||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
@@ -1016,6 +1189,13 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if account.Type == AccountTypeOAuth {
|
||||||
|
bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:"))
|
||||||
|
if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE {
|
||||||
|
return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Parse usage
|
// Parse usage
|
||||||
var response struct {
|
var response struct {
|
||||||
Usage struct {
|
Usage struct {
|
||||||
@@ -1055,6 +1235,110 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
|
|||||||
return usage, nil
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isEventStreamResponse(header http.Header) bool {
|
||||||
|
contentType := strings.ToLower(header.Get("Content-Type"))
|
||||||
|
return strings.Contains(contentType, "text/event-stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||||
|
bodyText := string(body)
|
||||||
|
finalResponse, ok := extractCodexFinalResponse(bodyText)
|
||||||
|
|
||||||
|
usage := &OpenAIUsage{}
|
||||||
|
if ok {
|
||||||
|
var response struct {
|
||||||
|
Usage struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
InputTokenDetails struct {
|
||||||
|
CachedTokens int `json:"cached_tokens"`
|
||||||
|
} `json:"input_tokens_details"`
|
||||||
|
} `json:"usage"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(finalResponse, &response); err == nil {
|
||||||
|
usage.InputTokens = response.Usage.InputTokens
|
||||||
|
usage.OutputTokens = response.Usage.OutputTokens
|
||||||
|
usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
|
||||||
|
}
|
||||||
|
body = finalResponse
|
||||||
|
if originalModel != mappedModel {
|
||||||
|
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
usage = s.parseSSEUsageFromBody(bodyText)
|
||||||
|
if originalModel != mappedModel {
|
||||||
|
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
|
||||||
|
}
|
||||||
|
body = []byte(bodyText)
|
||||||
|
}
|
||||||
|
|
||||||
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||||
|
|
||||||
|
contentType := "application/json; charset=utf-8"
|
||||||
|
if !ok {
|
||||||
|
contentType = resp.Header.Get("Content-Type")
|
||||||
|
if contentType == "" {
|
||||||
|
contentType = "text/event-stream"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Data(resp.StatusCode, contentType, body)
|
||||||
|
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
||||||
|
lines := strings.Split(body, "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
if !openaiSSEDataRe.MatchString(line) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||||
|
if data == "" || data == "[DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var event struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Response json.RawMessage `json:"response"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal([]byte(data), &event) != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if event.Type == "response.done" || event.Type == "response.completed" {
|
||||||
|
if len(event.Response) > 0 {
|
||||||
|
return event.Response, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
|
||||||
|
usage := &OpenAIUsage{}
|
||||||
|
lines := strings.Split(body, "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
if !openaiSSEDataRe.MatchString(line) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||||
|
if data == "" || data == "[DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.parseSSEUsage(data, usage)
|
||||||
|
}
|
||||||
|
return usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
|
||||||
|
lines := strings.Split(body, "\n")
|
||||||
|
for i, line := range lines {
|
||||||
|
if !openaiSSEDataRe.MatchString(line) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
|
||||||
|
}
|
||||||
|
return strings.Join(lines, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||||
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||||
@@ -1094,101 +1378,6 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
|||||||
return newBody
|
return newBody
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format
|
|
||||||
// that the ChatGPT internal Codex API expects.
|
|
||||||
//
|
|
||||||
// AI SDK sends content as an array of typed objects:
|
|
||||||
//
|
|
||||||
// {"content": [{"type": "input_text", "text": "hello"}]}
|
|
||||||
//
|
|
||||||
// ChatGPT Codex API expects content as a simple string:
|
|
||||||
//
|
|
||||||
// {"content": "hello"}
|
|
||||||
//
|
|
||||||
// This function modifies reqBody in-place and returns true if any modification was made.
|
|
||||||
func normalizeInputForCodexAPI(reqBody map[string]any) bool {
|
|
||||||
input, ok := reqBody["input"]
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle case where input is a simple string (already compatible)
|
|
||||||
if _, isString := input.(string); isString {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle case where input is an array of messages
|
|
||||||
inputArray, ok := input.([]any)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
modified := false
|
|
||||||
for _, item := range inputArray {
|
|
||||||
message, ok := item.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
content, ok := message["content"]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// If content is already a string, no conversion needed
|
|
||||||
if _, isString := content.(string); isString {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// If content is an array (AI SDK format), convert to string
|
|
||||||
contentArray, ok := content.([]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract text from content array
|
|
||||||
var textParts []string
|
|
||||||
for _, part := range contentArray {
|
|
||||||
partMap, ok := part.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle different content types
|
|
||||||
partType, _ := partMap["type"].(string)
|
|
||||||
switch partType {
|
|
||||||
case "input_text", "text":
|
|
||||||
// Extract text from input_text or text type
|
|
||||||
if text, ok := partMap["text"].(string); ok {
|
|
||||||
textParts = append(textParts, text)
|
|
||||||
}
|
|
||||||
case "input_image", "image":
|
|
||||||
// For images, we need to preserve the original format
|
|
||||||
// as ChatGPT Codex API may support images in a different way
|
|
||||||
// For now, skip image parts (they will be lost in conversion)
|
|
||||||
// TODO: Consider preserving image data or handling it separately
|
|
||||||
continue
|
|
||||||
case "input_file", "file":
|
|
||||||
// Similar to images, file inputs may need special handling
|
|
||||||
continue
|
|
||||||
default:
|
|
||||||
// For unknown types, try to extract text if available
|
|
||||||
if text, ok := partMap["text"].(string); ok {
|
|
||||||
textParts = append(textParts, text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert content array to string
|
|
||||||
if len(textParts) > 0 {
|
|
||||||
message["content"] = strings.Join(textParts, "\n")
|
|
||||||
modified = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return modified
|
|
||||||
}
|
|
||||||
|
|
||||||
// OpenAIRecordUsageInput input for recording usage
|
// OpenAIRecordUsageInput input for recording usage
|
||||||
type OpenAIRecordUsageInput struct {
|
type OpenAIRecordUsageInput struct {
|
||||||
Result *OpenAIForwardResult
|
Result *OpenAIForwardResult
|
||||||
|
|||||||
@@ -220,7 +220,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
|
|||||||
Credentials: map[string]any{"base_url": "://invalid-url"},
|
Credentials: map[string]any{"base_url": "://invalid-url"},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
|
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("expected error for invalid base_url when allowlist disabled")
|
t.Fatalf("expected error for invalid base_url when allowlist disabled")
|
||||||
}
|
}
|
||||||
|
|||||||
213
backend/internal/service/openai_tool_continuation.go
Normal file
213
backend/internal/service/openai_tool_continuation.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
|
||||||
|
// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
|
||||||
|
// 或显式声明 tools/tool_choice。
|
||||||
|
func NeedsToolContinuation(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if hasNonEmptyString(reqBody["previous_response_id"]) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if hasToolsSignal(reqBody) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if hasToolChoiceSignal(reqBody) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if inputHasType(reqBody, "function_call_output") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if inputHasType(reqBody, "item_reference") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。
|
||||||
|
func HasFunctionCallOutput(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return inputHasType(reqBody, "function_call_output")
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call,
|
||||||
|
// 用于判断 function_call_output 是否具备可关联的上下文。
|
||||||
|
func HasToolCallContext(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "tool_call" && itemType != "function_call" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
|
||||||
|
// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。
|
||||||
|
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
|
||||||
|
if reqBody == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ids := make(map[string]struct{})
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "function_call_output" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||||
|
ids[callID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]string, 0, len(ids))
|
||||||
|
for id := range ids {
|
||||||
|
result = append(result, id)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
|
||||||
|
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
|
||||||
|
if reqBody == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "function_call_output" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callID, _ := itemMap["call_id"].(string)
|
||||||
|
if strings.TrimSpace(callID) == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。
|
||||||
|
// 用于仅依赖引用项完成续链场景的校验。
|
||||||
|
func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
|
||||||
|
if reqBody == nil || len(callIDs) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
referenceIDs := make(map[string]struct{})
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "item_reference" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idValue, _ := itemMap["id"].(string)
|
||||||
|
idValue = strings.TrimSpace(idValue)
|
||||||
|
if idValue == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
referenceIDs[idValue] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(referenceIDs) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, callID := range callIDs {
|
||||||
|
if _, ok := referenceIDs[callID]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// inputHasType 判断 input 中是否存在指定类型的 item。
|
||||||
|
func inputHasType(reqBody map[string]any, want string) bool {
|
||||||
|
input, ok := reqBody["input"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType == want {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasNonEmptyString 判断字段是否为非空字符串。
|
||||||
|
func hasNonEmptyString(value any) bool {
|
||||||
|
stringValue, ok := value.(string)
|
||||||
|
return ok && strings.TrimSpace(stringValue) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasToolsSignal 判断 tools 字段是否显式声明(存在且不为空)。
|
||||||
|
func hasToolsSignal(reqBody map[string]any) bool {
|
||||||
|
raw, exists := reqBody["tools"]
|
||||||
|
if !exists || raw == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if tools, ok := raw.([]any); ok {
|
||||||
|
return len(tools) > 0
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasToolChoiceSignal 判断 tool_choice 是否显式声明(非空或非 nil)。
|
||||||
|
func hasToolChoiceSignal(reqBody map[string]any) bool {
|
||||||
|
raw, exists := reqBody["tool_choice"]
|
||||||
|
if !exists || raw == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch value := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(value) != ""
|
||||||
|
case map[string]any:
|
||||||
|
return len(value) > 0
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
98
backend/internal/service/openai_tool_continuation_test.go
Normal file
98
backend/internal/service/openai_tool_continuation_test.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNeedsToolContinuationSignals(t *testing.T) {
|
||||||
|
// 覆盖所有触发续链的信号来源,确保判定逻辑完整。
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
body map[string]any
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "nil", body: nil, want: false},
|
||||||
|
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
|
||||||
|
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
|
||||||
|
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
|
||||||
|
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
|
||||||
|
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
|
||||||
|
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
|
||||||
|
{name: "tools_invalid", body: map[string]any{"tools": "bad"}, want: false},
|
||||||
|
{name: "tool_choice", body: map[string]any{"tool_choice": "auto"}, want: true},
|
||||||
|
{name: "tool_choice_object", body: map[string]any{"tool_choice": map[string]any{"type": "function"}}, want: true},
|
||||||
|
{name: "tool_choice_empty_object", body: map[string]any{"tool_choice": map[string]any{}}, want: false},
|
||||||
|
{name: "none", body: map[string]any{"input": []any{map[string]any{"type": "text", "text": "hi"}}}, want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.want, NeedsToolContinuation(tt.body))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasFunctionCallOutput(t *testing.T) {
|
||||||
|
// 仅当 input 中存在 function_call_output 才视为续链输出。
|
||||||
|
require.False(t, HasFunctionCallOutput(nil))
|
||||||
|
require.True(t, HasFunctionCallOutput(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||||
|
}))
|
||||||
|
require.False(t, HasFunctionCallOutput(map[string]any{
|
||||||
|
"input": "text",
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasToolCallContext(t *testing.T) {
|
||||||
|
// tool_call/function_call 必须包含 call_id,才能作为可关联上下文。
|
||||||
|
require.False(t, HasToolCallContext(nil))
|
||||||
|
require.True(t, HasToolCallContext(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}},
|
||||||
|
}))
|
||||||
|
require.True(t, HasToolCallContext(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}},
|
||||||
|
}))
|
||||||
|
require.False(t, HasToolCallContext(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "tool_call"}},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFunctionCallOutputCallIDs(t *testing.T) {
|
||||||
|
// 仅提取非空 call_id,去重后返回。
|
||||||
|
require.Empty(t, FunctionCallOutputCallIDs(nil))
|
||||||
|
callIDs := FunctionCallOutputCallIDs(map[string]any{
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": ""},
|
||||||
|
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.ElementsMatch(t, []string{"call_1"}, callIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
|
||||||
|
require.False(t, HasFunctionCallOutputMissingCallID(nil))
|
||||||
|
require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||||
|
}))
|
||||||
|
require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||||
|
"input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasItemReferenceForCallIDs(t *testing.T) {
|
||||||
|
// item_reference 需要覆盖所有 call_id 才视为可关联上下文。
|
||||||
|
require.False(t, HasItemReferenceForCallIDs(nil, []string{"call_1"}))
|
||||||
|
require.False(t, HasItemReferenceForCallIDs(map[string]any{}, []string{"call_1"}))
|
||||||
|
req := map[string]any{
|
||||||
|
"input": []any{
|
||||||
|
map[string]any{"type": "item_reference", "id": "call_1"},
|
||||||
|
map[string]any{"type": "item_reference", "id": "call_2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1"}))
|
||||||
|
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_2"}))
|
||||||
|
require.False(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_3"}))
|
||||||
|
}
|
||||||
194
backend/internal/service/ops_account_availability.go
Normal file
194
backend/internal/service/ops_account_availability.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetAccountAvailabilityStats returns current account availability stats.
//
// Query-level filtering is intentionally limited to platform/group to match the dashboard scope.
//
// It returns per-platform summaries, per-group summaries, per-account detail,
// and the timestamp the snapshot was taken at. All four results are nil when
// monitoring is disabled or the account listing fails.
func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFilter string, groupIDFilter *int64) (
	map[string]*PlatformAvailability,
	map[int64]*GroupAvailability,
	map[int64]*AccountAvailability,
	*time.Time,
	error,
) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, nil, nil, nil, err
	}

	// Platform filtering is pushed down to the repository; group filtering
	// happens in memory below.
	accounts, err := s.listAllAccountsForOps(ctx, platformFilter)
	if err != nil {
		return nil, nil, nil, nil, err
	}

	// In-memory group filter: keep accounts belonging to the requested group.
	if groupIDFilter != nil && *groupIDFilter > 0 {
		filtered := make([]Account, 0, len(accounts))
		for _, acc := range accounts {
			for _, grp := range acc.Groups {
				if grp != nil && grp.ID == *groupIDFilter {
					filtered = append(filtered, acc)
					break
				}
			}
		}
		accounts = filtered
	}

	// One shared "now" for all time comparisons keeps flags mutually consistent.
	now := time.Now()
	collectedAt := now

	platform := make(map[string]*PlatformAvailability)
	group := make(map[int64]*GroupAvailability)
	account := make(map[int64]*AccountAvailability)

	for _, acc := range accounts {
		if acc.ID <= 0 {
			continue
		}

		// Derive transient status flags from the account's deadline fields.
		isTempUnsched := false
		if acc.TempUnschedulableUntil != nil && now.Before(*acc.TempUnschedulableUntil) {
			isTempUnsched = true
		}

		isRateLimited := acc.RateLimitResetAt != nil && now.Before(*acc.RateLimitResetAt)
		isOverloaded := acc.OverloadUntil != nil && now.Before(*acc.OverloadUntil)
		hasError := acc.Status == StatusError

		// Normalize exclusive status flags so the UI doesn't show conflicting badges.
		if hasError {
			isRateLimited = false
			isOverloaded = false
		}

		// Available = active, schedulable, and not in any transient penalty state.
		isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched

		// Per-platform rollup (accounts with an empty platform are skipped here).
		if acc.Platform != "" {
			if _, ok := platform[acc.Platform]; !ok {
				platform[acc.Platform] = &PlatformAvailability{
					Platform: acc.Platform,
				}
			}
			p := platform[acc.Platform]
			p.TotalAccounts++
			if isAvailable {
				p.AvailableCount++
			}
			if isRateLimited {
				p.RateLimitCount++
			}
			if hasError {
				p.ErrorCount++
			}
		}

		// Per-group rollup; an account in N groups is counted once per group.
		for _, grp := range acc.Groups {
			if grp == nil || grp.ID <= 0 {
				continue
			}
			if _, ok := group[grp.ID]; !ok {
				group[grp.ID] = &GroupAvailability{
					GroupID:   grp.ID,
					GroupName: grp.Name,
					Platform:  grp.Platform,
				}
			}
			g := group[grp.ID]
			g.TotalAccounts++
			if isAvailable {
				g.AvailableCount++
			}
			if isRateLimited {
				g.RateLimitCount++
			}
			if hasError {
				g.ErrorCount++
			}
		}

		// For display purposes only the first group is shown on the account row.
		displayGroupID := int64(0)
		displayGroupName := ""
		if len(acc.Groups) > 0 && acc.Groups[0] != nil {
			displayGroupID = acc.Groups[0].ID
			displayGroupName = acc.Groups[0].Name
		}

		item := &AccountAvailability{
			AccountID:   acc.ID,
			AccountName: acc.Name,
			Platform:    acc.Platform,
			GroupID:     displayGroupID,
			GroupName:   displayGroupName,
			Status:      acc.Status,

			IsAvailable:   isAvailable,
			IsRateLimited: isRateLimited,
			IsOverloaded:  isOverloaded,
			HasError:      hasError,

			ErrorMessage: acc.ErrorMessage,
		}

		// Remaining seconds are only attached while the penalty is still active
		// and strictly positive, so the UI never shows a 0/negative countdown.
		if isRateLimited && acc.RateLimitResetAt != nil {
			item.RateLimitResetAt = acc.RateLimitResetAt
			remainingSec := int64(time.Until(*acc.RateLimitResetAt).Seconds())
			if remainingSec > 0 {
				item.RateLimitRemainingSec = &remainingSec
			}
		}
		if isOverloaded && acc.OverloadUntil != nil {
			item.OverloadUntil = acc.OverloadUntil
			remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
			if remainingSec > 0 {
				item.OverloadRemainingSec = &remainingSec
			}
		}
		if isTempUnsched && acc.TempUnschedulableUntil != nil {
			item.TempUnschedulableUntil = acc.TempUnschedulableUntil
		}

		account[acc.ID] = item
	}

	return platform, group, account, &collectedAt, nil
}
|
||||||
|
|
||||||
|
// OpsAccountAvailability bundles the availability view returned to callers:
// an optional group-level summary plus per-account detail and the timestamp
// the snapshot was collected at.
type OpsAccountAvailability struct {
	Group       *GroupAvailability             // summary for the requested group; nil when no group filter applied or group unknown
	Accounts    map[int64]*AccountAvailability // per-account availability keyed by account ID; never nil
	CollectedAt *time.Time                     // when the underlying stats were computed
}
|
||||||
|
|
||||||
|
func (s *OpsService) GetAccountAvailability(ctx context.Context, platformFilter string, groupIDFilter *int64) (*OpsAccountAvailability, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, errors.New("ops service is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.getAccountAvailability != nil {
|
||||||
|
return s.getAccountAvailability(ctx, platformFilter, groupIDFilter)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, groupStats, accountStats, collectedAt, err := s.GetAccountAvailabilityStats(ctx, platformFilter, groupIDFilter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var group *GroupAvailability
|
||||||
|
if groupIDFilter != nil && *groupIDFilter > 0 {
|
||||||
|
group = groupStats[*groupIDFilter]
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountStats == nil {
|
||||||
|
accountStats = map[int64]*AccountAvailability{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OpsAccountAvailability{
|
||||||
|
Group: group,
|
||||||
|
Accounts: accountStats,
|
||||||
|
CollectedAt: collectedAt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
46
backend/internal/service/ops_advisory_lock.go
Normal file
46
backend/internal/service/ops_advisory_lock.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"hash/fnv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// hashAdvisoryLockID maps an arbitrary lock key onto the signed 64-bit space
// expected by pg_try_advisory_lock, using FNV-1a for a stable, dependency-free
// digest. The uint64 sum is reinterpreted as int64 (wraparound is fine: only
// stability and spread matter here).
func hashAdvisoryLockID(key string) int64 {
	hasher := fnv.New64a()
	_, _ = hasher.Write([]byte(key)) // hash.Hash.Write never returns an error
	sum := hasher.Sum64()
	return int64(sum)
}
|
||||||
|
|
||||||
|
func tryAcquireDBAdvisoryLock(ctx context.Context, db *sql.DB, lockID int64) (func(), bool) {
|
||||||
|
if db == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := db.Conn(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
acquired := false
|
||||||
|
if err := conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", lockID).Scan(&acquired); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if !acquired {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
release := func() {
|
||||||
|
unlockCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, _ = conn.ExecContext(unlockCtx, "SELECT pg_advisory_unlock($1)", lockID)
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
return release, true
|
||||||
|
}
|
||||||
443
backend/internal/service/ops_aggregation_service.go
Normal file
443
backend/internal/service/ops_aggregation_service.go
Normal file
@@ -0,0 +1,443 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// Heartbeat job names recorded via UpsertJobHeartbeat.
	opsAggHourlyJobName = "ops_preaggregation_hourly"
	opsAggDailyJobName  = "ops_preaggregation_daily"

	// How often each background loop wakes up to backfill.
	opsAggHourlyInterval = 10 * time.Minute
	opsAggDailyInterval  = 1 * time.Hour

	// Keep in sync with ops retention target (vNext default 30d).
	opsAggBackfillWindow = 30 * 24 * time.Hour

	// Recompute overlap to absorb late-arriving rows near boundaries.
	opsAggHourlyOverlap = 2 * time.Hour
	opsAggDailyOverlap  = 48 * time.Hour

	// Span covered by a single upsert call when walking the backfill window.
	opsAggHourlyChunk = 24 * time.Hour
	opsAggDailyChunk  = 7 * 24 * time.Hour

	// Delay around boundaries (e.g. 10:00..10:05) to avoid aggregating buckets
	// that may still receive late inserts.
	opsAggSafeDelay = 5 * time.Minute

	// Timeouts: one short bound for point lookups, and per-pass bounds for
	// the hourly/daily aggregation runs.
	opsAggMaxQueryTimeout = 3 * time.Second
	opsAggHourlyTimeout   = 5 * time.Minute
	opsAggDailyTimeout    = 2 * time.Minute

	// Redis leader-lock keys/TTLs ensuring one replica aggregates at a time.
	opsAggHourlyLeaderLockKey = "ops:aggregation:hourly:leader"
	opsAggDailyLeaderLockKey  = "ops:aggregation:daily:leader"

	opsAggHourlyLeaderLockTTL = 15 * time.Minute
	opsAggDailyLeaderLockTTL  = 10 * time.Minute
)
|
||||||
|
|
||||||
|
// OpsAggregationService periodically backfills ops_metrics_hourly / ops_metrics_daily
// for stable long-window dashboard queries.
//
// It is safe to run in multi-replica deployments when Redis is available (leader lock).
type OpsAggregationService struct {
	opsRepo     OpsRepository     // metric buckets + job heartbeat persistence
	settingRepo SettingRepository // runtime "monitoring enabled" switch; nil means always enabled
	cfg         *config.Config    // static switches (Ops.Enabled / Ops.Aggregation.Enabled)

	db          *sql.DB       // Postgres advisory-lock fallback when Redis is unavailable
	redisClient *redis.Client // preferred leader-lock backend; may be nil
	instanceID  string        // unique per-process owner token for the leader lock

	stopCh    chan struct{} // closed by Stop to terminate both loops
	startOnce sync.Once
	stopOnce  sync.Once

	hourlyMu sync.Mutex // serializes hourly passes within this process
	dailyMu  sync.Mutex // serializes daily passes within this process

	skipLogMu sync.Mutex // guards skipLogAt
	skipLogAt time.Time  // last "lock held elsewhere" log time (throttled to 1/min)
}
|
||||||
|
|
||||||
|
func NewOpsAggregationService(
|
||||||
|
opsRepo OpsRepository,
|
||||||
|
settingRepo SettingRepository,
|
||||||
|
db *sql.DB,
|
||||||
|
redisClient *redis.Client,
|
||||||
|
cfg *config.Config,
|
||||||
|
) *OpsAggregationService {
|
||||||
|
return &OpsAggregationService{
|
||||||
|
opsRepo: opsRepo,
|
||||||
|
settingRepo: settingRepo,
|
||||||
|
cfg: cfg,
|
||||||
|
db: db,
|
||||||
|
redisClient: redisClient,
|
||||||
|
instanceID: uuid.NewString(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsAggregationService) Start() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.startOnce.Do(func() {
|
||||||
|
if s.stopCh == nil {
|
||||||
|
s.stopCh = make(chan struct{})
|
||||||
|
}
|
||||||
|
go s.hourlyLoop()
|
||||||
|
go s.dailyLoop()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsAggregationService) Stop() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.stopOnce.Do(func() {
|
||||||
|
if s.stopCh != nil {
|
||||||
|
close(s.stopCh)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsAggregationService) hourlyLoop() {
|
||||||
|
// First run immediately.
|
||||||
|
s.aggregateHourly()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(opsAggHourlyInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
s.aggregateHourly()
|
||||||
|
case <-s.stopCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsAggregationService) dailyLoop() {
|
||||||
|
// First run immediately.
|
||||||
|
s.aggregateDaily()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(opsAggDailyInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
s.aggregateDaily()
|
||||||
|
case <-s.stopCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// aggregateHourly performs one hourly backfill pass: acquire the leader lock,
// determine the [start, end) range of stable full hours to (re)compute, upsert
// them in daily-sized chunks, and record a job heartbeat with the outcome.
// Errors are logged and reported via the heartbeat; the next tick retries.
func (s *OpsAggregationService) aggregateHourly() {
	if s == nil || s.opsRepo == nil {
		return
	}
	// Static config gates: both ops and aggregation must be enabled.
	if s.cfg != nil {
		if !s.cfg.Ops.Enabled {
			return
		}
		if !s.cfg.Ops.Aggregation.Enabled {
			return
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), opsAggHourlyTimeout)
	defer cancel()

	// Runtime setting gate (fail-open inside isMonitoringEnabled).
	if !s.isMonitoringEnabled(ctx) {
		return
	}

	// Only one replica aggregates at a time.
	release, ok := s.tryAcquireLeaderLock(ctx, opsAggHourlyLeaderLockKey, opsAggHourlyLeaderLockTTL, "[OpsAggregation][hourly]")
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}

	// Also serialize within this process (ticker vs. startup pass overlap).
	s.hourlyMu.Lock()
	defer s.hourlyMu.Unlock()

	startedAt := time.Now().UTC()
	runAt := startedAt

	// Aggregate stable full hours only.
	end := utcFloorToHour(time.Now().UTC().Add(-opsAggSafeDelay))
	start := end.Add(-opsAggBackfillWindow)

	// Resume from the latest bucket with overlap.
	{
		ctxMax, cancelMax := context.WithTimeout(context.Background(), opsAggMaxQueryTimeout)
		latest, ok, err := s.opsRepo.GetLatestHourlyBucketStart(ctxMax)
		cancelMax()
		if err != nil {
			// Non-fatal: fall back to the full backfill window.
			log.Printf("[OpsAggregation][hourly] failed to read latest bucket: %v", err)
		} else if ok {
			// Re-aggregate a small overlap to absorb late-arriving rows.
			candidate := latest.Add(-opsAggHourlyOverlap)
			if candidate.After(start) {
				start = candidate
			}
		}
	}

	start = utcFloorToHour(start)
	if !start.Before(end) {
		// Nothing new to aggregate; skip the heartbeat update entirely.
		return
	}

	// Walk the range in fixed chunks; stop at the first failure so the next
	// run resumes from the last successfully upserted bucket.
	var aggErr error
	for cursor := start; cursor.Before(end); cursor = cursor.Add(opsAggHourlyChunk) {
		chunkEnd := minTime(cursor.Add(opsAggHourlyChunk), end)
		if err := s.opsRepo.UpsertHourlyMetrics(ctx, cursor, chunkEnd); err != nil {
			aggErr = err
			log.Printf("[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err)
			break
		}
	}

	finishedAt := time.Now().UTC()
	durationMs := finishedAt.Sub(startedAt).Milliseconds()
	dur := durationMs

	// Record the outcome on the job heartbeat (best-effort, own timeout).
	if aggErr != nil {
		msg := truncateString(aggErr.Error(), 2048)
		errAt := finishedAt
		hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer hbCancel()
		_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
			JobName:        opsAggHourlyJobName,
			LastRunAt:      &runAt,
			LastErrorAt:    &errAt,
			LastError:      &msg,
			LastDurationMs: &dur,
		})
		return
	}

	successAt := finishedAt
	hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer hbCancel()
	_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsAggHourlyJobName,
		LastRunAt:      &runAt,
		LastSuccessAt:  &successAt,
		LastDurationMs: &dur,
	})
}
|
||||||
|
|
||||||
|
// aggregateDaily performs one daily backfill pass. It mirrors aggregateHourly
// with day-granularity buckets: leader lock, resume-with-overlap from the
// latest daily bucket, chunked upserts, and a heartbeat recording the outcome.
// Note the end bound is today's UTC midnight (no opsAggSafeDelay here; daily
// buckets only close at midnight, and the 48h overlap re-covers the boundary).
func (s *OpsAggregationService) aggregateDaily() {
	if s == nil || s.opsRepo == nil {
		return
	}
	// Static config gates: both ops and aggregation must be enabled.
	if s.cfg != nil {
		if !s.cfg.Ops.Enabled {
			return
		}
		if !s.cfg.Ops.Aggregation.Enabled {
			return
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), opsAggDailyTimeout)
	defer cancel()

	if !s.isMonitoringEnabled(ctx) {
		return
	}

	// Only one replica aggregates at a time.
	release, ok := s.tryAcquireLeaderLock(ctx, opsAggDailyLeaderLockKey, opsAggDailyLeaderLockTTL, "[OpsAggregation][daily]")
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}

	// Serialize within this process as well.
	s.dailyMu.Lock()
	defer s.dailyMu.Unlock()

	startedAt := time.Now().UTC()
	runAt := startedAt

	end := utcFloorToDay(time.Now().UTC())
	start := end.Add(-opsAggBackfillWindow)

	// Resume from the latest daily bucket, re-covering a 48h overlap.
	{
		ctxMax, cancelMax := context.WithTimeout(context.Background(), opsAggMaxQueryTimeout)
		latest, ok, err := s.opsRepo.GetLatestDailyBucketDate(ctxMax)
		cancelMax()
		if err != nil {
			// Non-fatal: fall back to the full backfill window.
			log.Printf("[OpsAggregation][daily] failed to read latest bucket: %v", err)
		} else if ok {
			candidate := latest.Add(-opsAggDailyOverlap)
			if candidate.After(start) {
				start = candidate
			}
		}
	}

	start = utcFloorToDay(start)
	if !start.Before(end) {
		// Nothing new to aggregate.
		return
	}

	// Chunked walk; stop at the first failure so the next run can resume.
	var aggErr error
	for cursor := start; cursor.Before(end); cursor = cursor.Add(opsAggDailyChunk) {
		chunkEnd := minTime(cursor.Add(opsAggDailyChunk), end)
		if err := s.opsRepo.UpsertDailyMetrics(ctx, cursor, chunkEnd); err != nil {
			aggErr = err
			log.Printf("[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err)
			break
		}
	}

	finishedAt := time.Now().UTC()
	durationMs := finishedAt.Sub(startedAt).Milliseconds()
	dur := durationMs

	// Record the outcome on the job heartbeat (best-effort, own timeout).
	if aggErr != nil {
		msg := truncateString(aggErr.Error(), 2048)
		errAt := finishedAt
		hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer hbCancel()
		_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
			JobName:        opsAggDailyJobName,
			LastRunAt:      &runAt,
			LastErrorAt:    &errAt,
			LastError:      &msg,
			LastDurationMs: &dur,
		})
		return
	}

	successAt := finishedAt
	hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer hbCancel()
	_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsAggDailyJobName,
		LastRunAt:      &runAt,
		LastSuccessAt:  &successAt,
		LastDurationMs: &dur,
	})
}
|
||||||
|
|
||||||
|
func (s *OpsAggregationService) isMonitoringEnabled(ctx context.Context) bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.cfg != nil && !s.cfg.Ops.Enabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.settingRepo == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := s.settingRepo.GetValue(ctx, SettingKeyOpsMonitoringEnabled)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrSettingNotFound) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||||
|
case "false", "0", "off", "disabled":
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// opsAggReleaseScript atomically deletes the leader-lock key only while it is
// still owned by this instance (compare-and-delete on the owner token), so a
// slow replica can never release a lock another replica has since acquired.
var opsAggReleaseScript = redis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
end
return 0
`)
|
||||||
|
|
||||||
|
// tryAcquireLeaderLock obtains a cross-replica leader lock for the given key.
// It returns (release, true) on success; release may be nil-checked by callers.
// Redis SET NX with a TTL is preferred; on a Redis *error* (not contention)
// it falls back to a Postgres advisory lock so aggregation still runs when
// Redis is flaky. Contention on either backend logs a throttled skip message.
func (s *OpsAggregationService) tryAcquireLeaderLock(ctx context.Context, key string, ttl time.Duration, logPrefix string) (func(), bool) {
	if s == nil {
		return nil, false
	}
	if ctx == nil {
		ctx = context.Background()
	}

	// Prefer Redis leader lock when available (multi-instance), but avoid stampeding
	// the DB when Redis is flaky by falling back to a DB advisory lock.
	if s.redisClient != nil {
		ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
		if err == nil {
			if !ok {
				// Lock held by another instance: contention, not failure.
				s.maybeLogSkip(logPrefix)
				return nil, false
			}
			release := func() {
				// Compare-and-delete so we never release someone else's lock
				// (e.g. after our TTL expired and another replica took over).
				ctx2, cancel := context.WithTimeout(context.Background(), 2*time.Second)
				defer cancel()
				_, _ = opsAggReleaseScript.Run(ctx2, s.redisClient, []string{key}, s.instanceID).Result()
			}
			return release, true
		}
		// Redis error: fall through to DB advisory lock.
	}

	// Fallback path: Postgres advisory lock keyed by a stable hash of the key.
	release, ok := tryAcquireDBAdvisoryLock(ctx, s.db, hashAdvisoryLockID(key))
	if !ok {
		s.maybeLogSkip(logPrefix)
		return nil, false
	}
	return release, true
}
|
||||||
|
|
||||||
|
func (s *OpsAggregationService) maybeLogSkip(prefix string) {
|
||||||
|
s.skipLogMu.Lock()
|
||||||
|
defer s.skipLogMu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
if !s.skipLogAt.IsZero() && now.Sub(s.skipLogAt) < time.Minute {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.skipLogAt = now
|
||||||
|
if prefix == "" {
|
||||||
|
prefix = "[OpsAggregation]"
|
||||||
|
}
|
||||||
|
log.Printf("%s leader lock held by another instance; skipping", prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func utcFloorToHour(t time.Time) time.Time {
|
||||||
|
return t.UTC().Truncate(time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
func utcFloorToDay(t time.Time) time.Time {
|
||||||
|
u := t.UTC()
|
||||||
|
y, m, d := u.Date()
|
||||||
|
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
|
|
||||||
|
func minTime(a, b time.Time) time.Time {
|
||||||
|
if a.Before(b) {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
913
backend/internal/service/ops_alert_evaluator_service.go
Normal file
913
backend/internal/service/ops_alert_evaluator_service.go
Normal file
@@ -0,0 +1,913 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"math"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// Heartbeat job name recorded for the evaluator loop.
	opsAlertEvaluatorJobName = "ops_alert_evaluator"

	// Upper bound for one full evaluation pass over all rules.
	opsAlertEvaluatorTimeout = 45 * time.Second
	// Redis key / TTL for the single-evaluator leader lock.
	opsAlertEvaluatorLeaderLockKey = "ops:alert:evaluator:leader"
	opsAlertEvaluatorLeaderLockTTL = 90 * time.Second
	// Minimum gap between "skipped: lock held elsewhere" log lines.
	opsAlertEvaluatorSkipLogInterval = 1 * time.Minute
)
|
||||||
|
|
||||||
|
// opsAlertEvaluatorReleaseScript atomically deletes the evaluator leader-lock
// key only while it still holds this instance's owner token (compare-and-delete),
// preventing release of a lock another replica has since acquired.
var opsAlertEvaluatorReleaseScript = redis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
end
return 0
`)
|
||||||
|
|
||||||
|
// OpsAlertEvaluatorService periodically evaluates alert rules against recent
// ops metrics, fires/resolves alert events, and optionally sends notification
// emails. Safe for multi-replica deployments via a Redis leader lock.
type OpsAlertEvaluatorService struct {
	opsService   *OpsService   // runtime settings + metric helpers; may be nil
	opsRepo      OpsRepository // rule and event persistence
	emailService *EmailService // outbound notifications; may be nil

	redisClient *redis.Client  // leader-lock backend; may be nil
	cfg         *config.Config // static Ops.* switches
	instanceID  string         // unique owner token for the leader lock

	stopCh    chan struct{} // closed by Stop to terminate the run loop
	startOnce sync.Once
	stopOnce  sync.Once
	wg        sync.WaitGroup // tracks the run goroutine so Stop can wait

	mu         sync.Mutex                   // guards ruleStates
	ruleStates map[int64]*opsAlertRuleState // per-rule breach-streak bookkeeping

	emailLimiter *slidingWindowLimiter // caps notification email volume per window

	skipLogMu sync.Mutex // guards skipLogAt
	skipLogAt time.Time  // throttles "lock held elsewhere" logging

	warnNoRedisOnce sync.Once // one-shot guard (used elsewhere in this file)
}
|
||||||
|
|
||||||
|
// opsAlertRuleState is the evaluator's in-memory, per-rule memory between passes.
type opsAlertRuleState struct {
	LastEvaluatedAt     time.Time // when this rule was last evaluated
	ConsecutiveBreaches int       // back-to-back breached evaluations; reset on a clean pass or a long gap
}
|
||||||
|
|
||||||
|
func NewOpsAlertEvaluatorService(
|
||||||
|
opsService *OpsService,
|
||||||
|
opsRepo OpsRepository,
|
||||||
|
emailService *EmailService,
|
||||||
|
redisClient *redis.Client,
|
||||||
|
cfg *config.Config,
|
||||||
|
) *OpsAlertEvaluatorService {
|
||||||
|
return &OpsAlertEvaluatorService{
|
||||||
|
opsService: opsService,
|
||||||
|
opsRepo: opsRepo,
|
||||||
|
emailService: emailService,
|
||||||
|
redisClient: redisClient,
|
||||||
|
cfg: cfg,
|
||||||
|
instanceID: uuid.NewString(),
|
||||||
|
ruleStates: map[int64]*opsAlertRuleState{},
|
||||||
|
emailLimiter: newSlidingWindowLimiter(0, time.Hour),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsAlertEvaluatorService) Start() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.startOnce.Do(func() {
|
||||||
|
if s.stopCh == nil {
|
||||||
|
s.stopCh = make(chan struct{})
|
||||||
|
}
|
||||||
|
go s.run()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsAlertEvaluatorService) Stop() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.stopOnce.Do(func() {
|
||||||
|
if s.stopCh != nil {
|
||||||
|
close(s.stopCh)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
s.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// run is the evaluator's main goroutine: it evaluates immediately on start,
// then re-evaluates every interval until stopCh is closed.
func (s *OpsAlertEvaluatorService) run() {
	// NOTE(review): wg.Add executes inside this goroutine, so Stop's wg.Wait
	// can return before the goroutine has registered itself — confirm whether
	// the Add should move into Start, before `go s.run()`.
	s.wg.Add(1)
	defer s.wg.Done()

	// Start immediately to produce early feedback in ops dashboard.
	timer := time.NewTimer(0)
	defer timer.Stop()

	for {
		select {
		case <-timer.C:
			// Re-read the interval each tick so runtime-setting changes take
			// effect without restarting the service.
			interval := s.getInterval()
			s.evaluateOnce(interval)
			timer.Reset(interval)
		case <-s.stopCh:
			return
		}
	}
}
|
||||||
|
|
||||||
|
func (s *OpsAlertEvaluatorService) getInterval() time.Duration {
|
||||||
|
// Default.
|
||||||
|
interval := 60 * time.Second
|
||||||
|
|
||||||
|
if s == nil || s.opsService == nil {
|
||||||
|
return interval
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cfg, err := s.opsService.GetOpsAlertRuntimeSettings(ctx)
|
||||||
|
if err != nil || cfg == nil {
|
||||||
|
return interval
|
||||||
|
}
|
||||||
|
if cfg.EvaluationIntervalSeconds <= 0 {
|
||||||
|
return interval
|
||||||
|
}
|
||||||
|
if cfg.EvaluationIntervalSeconds < 1 {
|
||||||
|
return interval
|
||||||
|
}
|
||||||
|
if cfg.EvaluationIntervalSeconds > int((24 * time.Hour).Seconds()) {
|
||||||
|
return interval
|
||||||
|
}
|
||||||
|
return time.Duration(cfg.EvaluationIntervalSeconds) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
// evaluateOnce runs one full evaluation pass over all alert rules.
//
// Per enabled rule it: computes the metric over the rule's window, updates the
// in-memory consecutive-breach streak, and either fires a new alert event
// (respecting active-event dedup and the cooldown) or resolves the active
// event when the rule is no longer breached. The pass is guarded by a leader
// lock so only one replica evaluates, and its outcome is recorded on the job
// heartbeat. interval is the caller's tick period, used to translate a rule's
// sustained-minutes requirement into a breach count.
func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
	if s == nil || s.opsRepo == nil {
		return
	}
	if s.cfg != nil && !s.cfg.Ops.Enabled {
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), opsAlertEvaluatorTimeout)
	defer cancel()

	if s.opsService != nil && !s.opsService.IsMonitoringEnabled(ctx) {
		return
	}

	// Load runtime settings, falling back to defaults on any failure.
	runtimeCfg := defaultOpsAlertRuntimeSettings()
	if s.opsService != nil {
		if loaded, err := s.opsService.GetOpsAlertRuntimeSettings(ctx); err == nil && loaded != nil {
			runtimeCfg = loaded
		}
	}

	// Only one replica evaluates at a time.
	release, ok := s.tryAcquireLeaderLock(ctx, runtimeCfg.DistributedLock)
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}

	startedAt := time.Now().UTC()
	runAt := startedAt

	rules, err := s.opsRepo.ListAlertRules(ctx)
	if err != nil {
		s.recordHeartbeatError(runAt, time.Since(startedAt), err)
		log.Printf("[OpsAlertEvaluator] list rules failed: %v", err)
		return
	}

	// Align window ends to whole minutes so repeated runs see stable buckets.
	now := time.Now().UTC()
	safeEnd := now.Truncate(time.Minute)
	if safeEnd.IsZero() {
		safeEnd = now
	}

	// Best-effort snapshot shared by system-level metrics; errors are ignored
	// and such rules simply get no value this pass.
	systemMetrics, _ := s.opsRepo.GetLatestSystemMetrics(ctx, 1)

	// Cleanup stale state for removed rules.
	s.pruneRuleStates(rules)

	for _, rule := range rules {
		if rule == nil || !rule.Enabled || rule.ID <= 0 {
			continue
		}

		// Optional platform/group scope from the rule's filters.
		scopePlatform, scopeGroupID := parseOpsAlertRuleScope(rule.Filters)

		// Clamp the lookback window to at least one minute.
		windowMinutes := rule.WindowMinutes
		if windowMinutes <= 0 {
			windowMinutes = 1
		}
		windowStart := safeEnd.Add(-time.Duration(windowMinutes) * time.Minute)
		windowEnd := safeEnd

		metricValue, ok := s.computeRuleMetric(ctx, rule, systemMetrics, windowStart, windowEnd, scopePlatform, scopeGroupID)
		if !ok {
			// No data: clear the streak so stale breaches can't fire later.
			s.resetRuleState(rule.ID, now)
			continue
		}

		breachedNow := compareMetric(metricValue, rule.Operator, rule.Threshold)
		required := requiredSustainedBreaches(rule.SustainedMinutes, interval)
		consecutive := s.updateRuleBreaches(rule.ID, now, interval, breachedNow)

		activeEvent, err := s.opsRepo.GetActiveAlertEvent(ctx, rule.ID)
		if err != nil {
			log.Printf("[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err)
			continue
		}

		if breachedNow && consecutive >= required {
			// Already firing: nothing to do.
			if activeEvent != nil {
				continue
			}

			// Cooldown is measured from the most recent event of this rule.
			latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
			if err != nil {
				log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
				continue
			}
			if latestEvent != nil && rule.CooldownMinutes > 0 {
				cooldown := time.Duration(rule.CooldownMinutes) * time.Minute
				if now.Sub(latestEvent.FiredAt) < cooldown {
					continue
				}
			}

			firedEvent := &OpsAlertEvent{
				RuleID:         rule.ID,
				Severity:       strings.TrimSpace(rule.Severity),
				Status:         OpsAlertStatusFiring,
				Title:          fmt.Sprintf("%s: %s", strings.TrimSpace(rule.Severity), strings.TrimSpace(rule.Name)),
				Description:    buildOpsAlertDescription(rule, metricValue, windowMinutes, scopePlatform, scopeGroupID),
				MetricValue:    float64Ptr(metricValue),
				ThresholdValue: float64Ptr(rule.Threshold),
				Dimensions:     buildOpsAlertDimensions(scopePlatform, scopeGroupID),
				FiredAt:        now,
				CreatedAt:      now,
			}

			created, err := s.opsRepo.CreateAlertEvent(ctx, firedEvent)
			if err != nil {
				log.Printf("[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err)
				continue
			}

			// Email only for events that were actually persisted.
			if created != nil && created.ID > 0 {
				s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created)
			}
			continue
		}

		// Not breached: resolve active event if present.
		if activeEvent != nil {
			resolvedAt := now
			if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
				log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err)
			}
		}
	}

	s.recordHeartbeatSuccess(runAt, time.Since(startedAt))
}
|
||||||
|
|
||||||
|
func (s *OpsAlertEvaluatorService) pruneRuleStates(rules []*OpsAlertRule) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
live := map[int64]struct{}{}
|
||||||
|
for _, r := range rules {
|
||||||
|
if r != nil && r.ID > 0 {
|
||||||
|
live[r.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for id := range s.ruleStates {
|
||||||
|
if _, ok := live[id]; !ok {
|
||||||
|
delete(s.ruleStates, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsAlertEvaluatorService) resetRuleState(ruleID int64, now time.Time) {
|
||||||
|
if ruleID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
state, ok := s.ruleStates[ruleID]
|
||||||
|
if !ok {
|
||||||
|
state = &opsAlertRuleState{}
|
||||||
|
s.ruleStates[ruleID] = state
|
||||||
|
}
|
||||||
|
state.LastEvaluatedAt = now
|
||||||
|
state.ConsecutiveBreaches = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsAlertEvaluatorService) updateRuleBreaches(ruleID int64, now time.Time, interval time.Duration, breached bool) int {
|
||||||
|
if ruleID <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
state, ok := s.ruleStates[ruleID]
|
||||||
|
if !ok {
|
||||||
|
state = &opsAlertRuleState{}
|
||||||
|
s.ruleStates[ruleID] = state
|
||||||
|
}
|
||||||
|
|
||||||
|
if !state.LastEvaluatedAt.IsZero() && interval > 0 {
|
||||||
|
if now.Sub(state.LastEvaluatedAt) > interval*2 {
|
||||||
|
state.ConsecutiveBreaches = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
state.LastEvaluatedAt = now
|
||||||
|
if breached {
|
||||||
|
state.ConsecutiveBreaches++
|
||||||
|
} else {
|
||||||
|
state.ConsecutiveBreaches = 0
|
||||||
|
}
|
||||||
|
return state.ConsecutiveBreaches
|
||||||
|
}
|
||||||
|
|
||||||
|
func requiredSustainedBreaches(sustainedMinutes int, interval time.Duration) int {
|
||||||
|
if sustainedMinutes <= 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if interval <= 0 {
|
||||||
|
return sustainedMinutes
|
||||||
|
}
|
||||||
|
required := int(math.Ceil(float64(sustainedMinutes*60) / interval.Seconds()))
|
||||||
|
if required < 1 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return required
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOpsAlertRuleScope extracts the optional scope from a rule's filter map:
// a trimmed platform name and a positive group ID. The group_id value is
// accepted as float64 (JSON default), int64, int, or a numeric string;
// non-positive or unparsable values leave groupID nil.
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64) {
	if filters == nil {
		return "", nil
	}

	if raw, ok := filters["platform"].(string); ok {
		platform = strings.TrimSpace(raw)
	}

	// Only positive IDs count as a real group scope.
	setGroup := func(id int64) {
		if id > 0 {
			groupID = &id
		}
	}
	switch v := filters["group_id"].(type) {
	case float64:
		setGroup(int64(v))
	case int64:
		setGroup(v)
	case int:
		setGroup(int64(v))
	case string:
		if n, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
			setGroup(n)
		}
	}
	return platform, groupID
}
|
||||||
|
|
||||||
|
// computeRuleMetric resolves the current value of the metric watched by rule.
// The boolean result is false when the metric cannot be computed (missing
// snapshot data, unknown metric type, repository error, or a group-scoped
// metric evaluated without a positive group_id).
//
// Metrics fall into two families:
//   - instantaneous metrics served from the system-metrics snapshot or the
//     account-availability view (first switch), and
//   - windowed aggregates served from the dashboard overview between start
//     and end (second switch).
func (s *OpsAlertEvaluatorService) computeRuleMetric(
	ctx context.Context,
	rule *OpsAlertRule,
	systemMetrics *OpsSystemMetricsSnapshot,
	start time.Time,
	end time.Time,
	platform string,
	groupID *int64,
) (float64, bool) {
	if rule == nil {
		return 0, false
	}
	// Family 1: snapshot / availability based metrics.
	switch strings.TrimSpace(rule.MetricType) {
	case "cpu_usage_percent":
		if systemMetrics != nil && systemMetrics.CPUUsagePercent != nil {
			return *systemMetrics.CPUUsagePercent, true
		}
		return 0, false
	case "memory_usage_percent":
		if systemMetrics != nil && systemMetrics.MemoryUsagePercent != nil {
			return *systemMetrics.MemoryUsagePercent, true
		}
		return 0, false
	case "concurrency_queue_depth":
		if systemMetrics != nil && systemMetrics.ConcurrencyQueueDepth != nil {
			return float64(*systemMetrics.ConcurrencyQueueDepth), true
		}
		return 0, false
	case "group_available_accounts":
		// Group-scoped metric: a positive group_id is mandatory.
		if groupID == nil || *groupID <= 0 {
			return 0, false
		}
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		if availability.Group == nil {
			// No group data: report zero available accounts as a valid reading.
			return 0, true
		}
		return float64(availability.Group.AvailableCount), true
	case "group_available_ratio":
		if groupID == nil || *groupID <= 0 {
			return 0, false
		}
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		// computeGroupAvailableRatio handles nil/empty groups by returning 0.
		return computeGroupAvailableRatio(availability.Group), true
	case "account_rate_limited_count":
		// Count of rate-limited accounts in scope (group filter optional).
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
			return acc.IsRateLimited
		})), true
	case "account_error_count":
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		// Accounts marked temporarily unschedulable are excluded from the
		// error count — presumably already handled elsewhere; confirm intent.
		return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
			return acc.HasError && acc.TempUnschedulableUntil == nil
		})), true
	}

	// Family 2: windowed aggregates from the dashboard overview (raw mode).
	overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{
		StartTime: start,
		EndTime:   end,
		Platform:  platform,
		GroupID:   groupID,
		QueryMode: OpsQueryModeRaw,
	})
	if err != nil {
		return 0, false
	}
	if overview == nil {
		return 0, false
	}

	switch strings.TrimSpace(rule.MetricType) {
	case "success_rate":
		// Rates are undefined without traffic in the window.
		if overview.RequestCountSLA <= 0 {
			return 0, false
		}
		// SLA is a fraction; alerts compare against percentages.
		return overview.SLA * 100, true
	case "error_rate":
		if overview.RequestCountSLA <= 0 {
			return 0, false
		}
		return overview.ErrorRate * 100, true
	case "upstream_error_rate":
		if overview.RequestCountSLA <= 0 {
			return 0, false
		}
		return overview.UpstreamErrorRate * 100, true
	case "p95_latency_ms":
		if overview.Duration.P95 == nil {
			return 0, false
		}
		return float64(*overview.Duration.P95), true
	case "p99_latency_ms":
		if overview.Duration.P99 == nil {
			return 0, false
		}
		return float64(*overview.Duration.P99), true
	default:
		// Unknown metric type: treat as not computable.
		return 0, false
	}
}
|
||||||
|
|
||||||
|
// compareMetric reports whether value satisfies "value <operator> threshold".
// The operator is trimmed of surrounding whitespace; any operator other than
// >, >=, <, <=, ==, != never matches.
func compareMetric(value float64, operator string, threshold float64) bool {
	op := strings.TrimSpace(operator)
	switch op {
	case "==":
		return value == threshold
	case "!=":
		return value != threshold
	case "<":
		return value < threshold
	case "<=":
		return value <= threshold
	case ">":
		return value > threshold
	case ">=":
		return value >= threshold
	}
	return false
}
|
||||||
|
|
||||||
|
// buildOpsAlertDimensions packs the non-empty alert scope (platform and/or
// positive group ID) into a dimensions map. Returns nil when the scope is
// empty so callers can omit the field entirely.
func buildOpsAlertDimensions(platform string, groupID *int64) map[string]any {
	dims := make(map[string]any, 2)
	if p := strings.TrimSpace(platform); p != "" {
		dims["platform"] = p
	}
	if groupID != nil && *groupID > 0 {
		dims["group_id"] = *groupID
	}
	if len(dims) == 0 {
		return nil
	}
	return dims
}
|
||||||
|
|
||||||
|
func buildOpsAlertDescription(rule *OpsAlertRule, value float64, windowMinutes int, platform string, groupID *int64) string {
|
||||||
|
if rule == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
scope := "overall"
|
||||||
|
if strings.TrimSpace(platform) != "" {
|
||||||
|
scope = fmt.Sprintf("platform=%s", strings.TrimSpace(platform))
|
||||||
|
}
|
||||||
|
if groupID != nil && *groupID > 0 {
|
||||||
|
scope = fmt.Sprintf("%s group_id=%d", scope, *groupID)
|
||||||
|
}
|
||||||
|
if windowMinutes <= 0 {
|
||||||
|
windowMinutes = 1
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s %s %.2f (current %.2f) over last %dm (%s)",
|
||||||
|
strings.TrimSpace(rule.MetricType),
|
||||||
|
strings.TrimSpace(rule.Operator),
|
||||||
|
rule.Threshold,
|
||||||
|
value,
|
||||||
|
windowMinutes,
|
||||||
|
strings.TrimSpace(scope),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybeSendAlertEmail sends a best-effort notification email for a fired alert
// event, subject to a chain of gates: the event must not already be marked as
// emailed, the rule must opt in to email, alert emails must be enabled with at
// least one recipient, the rule's severity must clear the configured minimum,
// and the alert must not be silenced. Per-recipient sends are rate limited and
// individual failures are ignored; if at least one send succeeds the event is
// marked as emailed.
func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) {
	if s == nil || s.emailService == nil || s.opsService == nil || event == nil || rule == nil {
		return
	}
	// Never re-send for an event that was already emailed.
	if event.EmailSent {
		return
	}
	if !rule.NotifyEmail {
		return
	}

	emailCfg, err := s.opsService.GetEmailNotificationConfig(ctx)
	if err != nil || emailCfg == nil || !emailCfg.Alert.Enabled {
		return
	}

	if len(emailCfg.Alert.Recipients) == 0 {
		return
	}
	// Severity gate: rule severity must rank at or above the configured minimum.
	if !shouldSendOpsAlertEmailByMinSeverity(strings.TrimSpace(emailCfg.Alert.MinSeverity), strings.TrimSpace(rule.Severity)) {
		return
	}

	// Silencing gate: suppress during global or per-rule silence windows.
	if runtimeCfg != nil && runtimeCfg.Silencing.Enabled {
		if isOpsAlertSilenced(time.Now().UTC(), rule, event, runtimeCfg.Silencing) {
			return
		}
	}

	// Apply/update rate limiter.
	s.emailLimiter.SetLimit(emailCfg.Alert.RateLimitPerHour)

	subject := fmt.Sprintf("[Ops Alert][%s] %s", strings.TrimSpace(rule.Severity), strings.TrimSpace(rule.Name))
	body := buildOpsAlertEmailBody(rule, event)

	anySent := false
	for _, to := range emailCfg.Alert.Recipients {
		addr := strings.TrimSpace(to)
		if addr == "" {
			continue
		}
		// Each recipient consumes one rate-limiter slot; skip when exhausted.
		if !s.emailLimiter.Allow(time.Now().UTC()) {
			continue
		}
		if err := s.emailService.SendEmail(ctx, addr, subject, body); err != nil {
			// Ignore per-recipient failures; continue best-effort.
			continue
		}
		anySent = true
	}

	if anySent {
		// Use a fresh context so the flag is persisted even if ctx was canceled.
		_ = s.opsRepo.UpdateAlertEventEmailSent(context.Background(), event.ID, true)
	}
}
|
||||||
|
|
||||||
|
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
|
||||||
|
if rule == nil || event == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
metric := strings.TrimSpace(rule.MetricType)
|
||||||
|
value := "-"
|
||||||
|
threshold := fmt.Sprintf("%.2f", rule.Threshold)
|
||||||
|
if event.MetricValue != nil {
|
||||||
|
value = fmt.Sprintf("%.2f", *event.MetricValue)
|
||||||
|
}
|
||||||
|
if event.ThresholdValue != nil {
|
||||||
|
threshold = fmt.Sprintf("%.2f", *event.ThresholdValue)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(`
|
||||||
|
<h2>Ops Alert</h2>
|
||||||
|
<p><b>Rule</b>: %s</p>
|
||||||
|
<p><b>Severity</b>: %s</p>
|
||||||
|
<p><b>Status</b>: %s</p>
|
||||||
|
<p><b>Metric</b>: %s %s %s</p>
|
||||||
|
<p><b>Fired at</b>: %s</p>
|
||||||
|
<p><b>Description</b>: %s</p>
|
||||||
|
`,
|
||||||
|
htmlEscape(rule.Name),
|
||||||
|
htmlEscape(rule.Severity),
|
||||||
|
htmlEscape(event.Status),
|
||||||
|
htmlEscape(metric),
|
||||||
|
htmlEscape(rule.Operator),
|
||||||
|
htmlEscape(fmt.Sprintf("%s (threshold %s)", value, threshold)),
|
||||||
|
event.FiredAt.Format(time.RFC3339),
|
||||||
|
htmlEscape(event.Description),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldSendOpsAlertEmailByMinSeverity reports whether a rule's severity
// clears the configured minimum email severity. An empty minimum allows
// everything; unknown levels rank below "info" and therefore never clear a
// known minimum.
func shouldSendOpsAlertEmailByMinSeverity(minSeverity string, ruleSeverity string) bool {
	minLevel := strings.ToLower(strings.TrimSpace(minSeverity))
	if minLevel == "" {
		return true
	}

	// Rank known levels; anything unrecognized ranks lowest.
	rank := func(level string) int {
		switch level {
		case "critical":
			return 3
		case "warning":
			return 2
		case "info":
			return 1
		}
		return 0
	}
	return rank(opsEmailSeverityForOps(ruleSeverity)) >= rank(minLevel)
}

// opsEmailSeverityForOps maps an ops rule severity (P0/P1/...) onto the email
// severity scale. P0 is critical, P1 is warning, everything else is info.
func opsEmailSeverityForOps(severity string) string {
	sev := strings.ToUpper(strings.TrimSpace(severity))
	if sev == "P0" {
		return "critical"
	}
	if sev == "P1" {
		return "warning"
	}
	return "info"
}
|
||||||
|
|
||||||
|
func isOpsAlertSilenced(now time.Time, rule *OpsAlertRule, event *OpsAlertEvent, silencing OpsAlertSilencingSettings) bool {
|
||||||
|
if !silencing.Enabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if now.IsZero() {
|
||||||
|
now = time.Now().UTC()
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(silencing.GlobalUntilRFC3339) != "" {
|
||||||
|
if t, err := time.Parse(time.RFC3339, strings.TrimSpace(silencing.GlobalUntilRFC3339)); err == nil {
|
||||||
|
if now.Before(t) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range silencing.Entries {
|
||||||
|
untilRaw := strings.TrimSpace(entry.UntilRFC3339)
|
||||||
|
if untilRaw == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
until, err := time.Parse(time.RFC3339, untilRaw)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if now.After(until) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if entry.RuleID != nil && rule != nil && rule.ID > 0 && *entry.RuleID != rule.ID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(entry.Severities) > 0 {
|
||||||
|
match := false
|
||||||
|
for _, s := range entry.Severities {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(s), strings.TrimSpace(event.Severity)) || strings.EqualFold(strings.TrimSpace(s), strings.TrimSpace(rule.Severity)) {
|
||||||
|
match = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !match {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryAcquireLeaderLock attempts to become the single active evaluator via a
// Redis SetNX lock. It returns (release, true) when this instance may run —
// release is nil when no lock was taken (lock disabled or Redis absent) — and
// (nil, false) when another instance holds the lock or Redis errored.
func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, lock OpsDistributedLockSettings) (func(), bool) {
	if !lock.Enabled {
		return nil, true
	}
	if s.redisClient == nil {
		// Without Redis, run unlocked; log the condition only once.
		s.warnNoRedisOnce.Do(func() {
			log.Printf("[OpsAlertEvaluator] redis not configured; running without distributed lock")
		})
		return nil, true
	}
	// Fall back to compiled-in defaults for key and TTL when unset.
	key := strings.TrimSpace(lock.Key)
	if key == "" {
		key = opsAlertEvaluatorLeaderLockKey
	}
	ttl := time.Duration(lock.TTLSeconds) * time.Second
	if ttl <= 0 {
		ttl = opsAlertEvaluatorLeaderLockTTL
	}

	// The instance ID as the value lets release verify ownership.
	ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
	if err != nil {
		// Prefer fail-closed to avoid duplicate evaluators stampeding the DB when Redis is flaky.
		// Single-node deployments can disable the distributed lock via runtime settings.
		// NOTE(review): this reuses warnNoRedisOnce, so whichever of the two
		// warnings fires first permanently suppresses the other — confirm intended.
		s.warnNoRedisOnce.Do(func() {
			log.Printf("[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err)
		})
		return nil, false
	}
	if !ok {
		// Another instance is leader; log at most once per interval.
		s.maybeLogSkip(key)
		return nil, false
	}
	// Release runs a compare-and-delete script so only the owner can unlock.
	return func() {
		_, _ = opsAlertEvaluatorReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result()
	}, true
}
|
||||||
|
|
||||||
|
func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) {
|
||||||
|
s.skipLogMu.Lock()
|
||||||
|
defer s.skipLogMu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
if !s.skipLogAt.IsZero() && now.Sub(s.skipLogAt) < opsAlertEvaluatorSkipLogInterval {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.skipLogAt = now
|
||||||
|
log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) {
|
||||||
|
if s == nil || s.opsRepo == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
durMs := duration.Milliseconds()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||||
|
JobName: opsAlertEvaluatorJobName,
|
||||||
|
LastRunAt: &runAt,
|
||||||
|
LastSuccessAt: &now,
|
||||||
|
LastDurationMs: &durMs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpsAlertEvaluatorService) recordHeartbeatError(runAt time.Time, duration time.Duration, err error) {
|
||||||
|
if s == nil || s.opsRepo == nil || err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
durMs := duration.Milliseconds()
|
||||||
|
msg := truncateString(err.Error(), 2048)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||||
|
JobName: opsAlertEvaluatorJobName,
|
||||||
|
LastRunAt: &runAt,
|
||||||
|
LastErrorAt: &now,
|
||||||
|
LastError: &msg,
|
||||||
|
LastDurationMs: &durMs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// htmlEscape escapes the five HTML-special characters (& < > " ') so
// untrusted text can be embedded safely in the alert email's HTML body.
// Behavior matches html.EscapeString from the standard library.
//
// BUGFIX: as previously written the replacer mapped each character to itself
// (the entity strings were lost), so nothing was actually escaped and rule or
// event text could inject markup into the email HTML.
func htmlEscape(s string) string {
	replacer := strings.NewReplacer(
		"&", "&amp;",
		"<", "&lt;",
		">", "&gt;",
		`"`, "&#34;",
		"'", "&#39;",
	)
	return replacer.Replace(s)
}
|
||||||
|
|
||||||
|
type slidingWindowLimiter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
limit int
|
||||||
|
window time.Duration
|
||||||
|
sent []time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSlidingWindowLimiter(limit int, window time.Duration) *slidingWindowLimiter {
|
||||||
|
if window <= 0 {
|
||||||
|
window = time.Hour
|
||||||
|
}
|
||||||
|
return &slidingWindowLimiter{
|
||||||
|
limit: limit,
|
||||||
|
window: window,
|
||||||
|
sent: []time.Time{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *slidingWindowLimiter) SetLimit(limit int) {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
l.limit = limit
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *slidingWindowLimiter) Allow(now time.Time) bool {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
if l.limit <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
cutoff := now.Add(-l.window)
|
||||||
|
keep := l.sent[:0]
|
||||||
|
for _, t := range l.sent {
|
||||||
|
if t.After(cutoff) {
|
||||||
|
keep = append(keep, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l.sent = keep
|
||||||
|
if len(l.sent) >= l.limit {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
l.sent = append(l.sent, now)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// computeGroupAvailableRatio returns the available percentage for a group.
|
||||||
|
// Formula: (AvailableCount / TotalAccounts) * 100.
|
||||||
|
// Returns 0 when TotalAccounts is 0.
|
||||||
|
func computeGroupAvailableRatio(group *GroupAvailability) float64 {
|
||||||
|
if group == nil || group.TotalAccounts <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return (float64(group.AvailableCount) / float64(group.TotalAccounts)) * 100
|
||||||
|
}
|
||||||
|
|
||||||
|
// countAccountsByCondition counts accounts that satisfy the given condition.
|
||||||
|
func countAccountsByCondition(accounts map[int64]*AccountAvailability, condition func(*AccountAvailability) bool) int64 {
|
||||||
|
if len(accounts) == 0 || condition == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
var count int64
|
||||||
|
for _, account := range accounts {
|
||||||
|
if account != nil && condition(account) {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
210
backend/internal/service/ops_alert_evaluator_service_test.go
Normal file
210
backend/internal/service/ops_alert_evaluator_service_test.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// stubOpsRepo embeds OpsRepository and overrides GetDashboardOverview with a
// canned response for tests; all other repository methods fall through to the
// (nil) embedded interface and will panic if called.
type stubOpsRepo struct {
	OpsRepository
	overview *OpsDashboardOverview // canned response; nil falls back to an empty overview
	err      error                 // when set, returned instead of any overview
}

// GetDashboardOverview returns the stubbed error or overview, ignoring the
// filter entirely.
func (s *stubOpsRepo) GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error) {
	if s.err != nil {
		return nil, s.err
	}
	if s.overview != nil {
		return s.overview, nil
	}
	return &OpsDashboardOverview{}, nil
}
|
||||||
|
|
||||||
|
// TestComputeGroupAvailableRatio verifies the available-percentage formula
// (AvailableCount / TotalAccounts * 100), including the zero-total guard.
func TestComputeGroupAvailableRatio(t *testing.T) {
	t.Parallel()

	// Normal case: 10 accounts with 8 available -> 80%.
	t.Run("正常情况: 10个账号, 8个可用 = 80%", func(t *testing.T) {
		t.Parallel()

		got := computeGroupAvailableRatio(&GroupAvailability{
			TotalAccounts:  10,
			AvailableCount: 8,
		})
		require.InDelta(t, 80.0, got, 0.0001)
	})

	// Boundary: zero total accounts must not divide by zero; returns 0.
	t.Run("边界情况: TotalAccounts = 0 应返回 0", func(t *testing.T) {
		t.Parallel()

		got := computeGroupAvailableRatio(&GroupAvailability{
			TotalAccounts:  0,
			AvailableCount: 8,
		})
		require.Equal(t, 0.0, got)
	})

	// Boundary: no available accounts -> 0%.
	t.Run("边界情况: AvailableCount = 0 应返回 0%", func(t *testing.T) {
		t.Parallel()

		got := computeGroupAvailableRatio(&GroupAvailability{
			TotalAccounts:  10,
			AvailableCount: 0,
		})
		require.Equal(t, 0.0, got)
	})
}
|
||||||
|
|
||||||
|
// TestCountAccountsByCondition verifies predicate-based account counting for
// the two predicates used by the evaluator (rate-limited, errored-and-
// schedulable) plus the empty-map boundary.
func TestCountAccountsByCondition(t *testing.T) {
	t.Parallel()

	// Rate-limited predicate: 2 of 3 accounts match.
	t.Run("测试限流账号统计: acc.IsRateLimited", func(t *testing.T) {
		t.Parallel()

		accounts := map[int64]*AccountAvailability{
			1: {IsRateLimited: true},
			2: {IsRateLimited: false},
			3: {IsRateLimited: true},
		}

		got := countAccountsByCondition(accounts, func(acc *AccountAvailability) bool {
			return acc.IsRateLimited
		})
		require.Equal(t, int64(2), got)
	})

	// Error predicate excludes accounts that are temporarily unschedulable:
	// only account 1 qualifies.
	t.Run("测试错误账号统计(排除临时不可调度): acc.HasError && acc.TempUnschedulableUntil == nil", func(t *testing.T) {
		t.Parallel()

		until := time.Now().UTC().Add(5 * time.Minute)
		accounts := map[int64]*AccountAvailability{
			1: {HasError: true},
			2: {HasError: true, TempUnschedulableUntil: &until},
			3: {HasError: false},
		}

		got := countAccountsByCondition(accounts, func(acc *AccountAvailability) bool {
			return acc.HasError && acc.TempUnschedulableUntil == nil
		})
		require.Equal(t, int64(1), got)
	})

	// Boundary: an empty map counts zero.
	t.Run("边界情况: 空 map 应返回 0", func(t *testing.T) {
		t.Parallel()

		got := countAccountsByCondition(map[int64]*AccountAvailability{}, func(acc *AccountAvailability) bool {
			return acc.IsRateLimited
		})
		require.Equal(t, int64(0), got)
	})
}
|
||||||
|
|
||||||
|
// TestComputeRuleMetricNewIndicators covers the availability-based metric
// types of computeRuleMetric using a stubbed OpsService whose
// getAccountAvailability hook returns a fixed fixture, plus a stub repo for
// the dashboard fallback path. It also checks that the group-scoped metrics
// refuse to compute without a group_id.
func TestComputeRuleMetricNewIndicators(t *testing.T) {
	t.Parallel()

	groupID := int64(101)
	platform := "openai"

	// Fixture: group of 10 accounts with 8 available; among the listed
	// accounts, two are rate limited and two have errors — one of which is
	// temporarily unschedulable and must be excluded from the error count.
	availability := &OpsAccountAvailability{
		Group: &GroupAvailability{
			GroupID:        groupID,
			TotalAccounts:  10,
			AvailableCount: 8,
		},
		Accounts: map[int64]*AccountAvailability{
			1: {IsRateLimited: true},
			2: {IsRateLimited: true},
			3: {HasError: true},
			4: {HasError: true, TempUnschedulableUntil: timePtr(time.Now().UTC().Add(2 * time.Minute))},
			5: {HasError: false, IsRateLimited: false},
		},
	}

	opsService := &OpsService{
		getAccountAvailability: func(_ context.Context, _ string, _ *int64) (*OpsAccountAvailability, error) {
			return availability, nil
		},
	}

	svc := &OpsAlertEvaluatorService{
		opsService: opsService,
		opsRepo:    &stubOpsRepo{overview: &OpsDashboardOverview{}},
	}

	start := time.Now().UTC().Add(-5 * time.Minute)
	end := time.Now().UTC()
	ctx := context.Background()

	tests := []struct {
		name       string
		metricType string
		groupID    *int64
		wantValue  float64
		wantOK     bool
	}{
		{
			name:       "group_available_accounts",
			metricType: "group_available_accounts",
			groupID:    &groupID,
			wantValue:  8,
			wantOK:     true,
		},
		{
			name:       "group_available_ratio",
			metricType: "group_available_ratio",
			groupID:    &groupID,
			wantValue:  80.0,
			wantOK:     true,
		},
		{
			name:       "account_rate_limited_count",
			metricType: "account_rate_limited_count",
			groupID:    nil,
			wantValue:  2,
			wantOK:     true,
		},
		{
			name:       "account_error_count",
			metricType: "account_error_count",
			groupID:    nil,
			wantValue:  1,
			wantOK:     true,
		},
		{
			name:       "group_available_accounts without group_id returns false",
			metricType: "group_available_accounts",
			groupID:    nil,
			wantValue:  0,
			wantOK:     false,
		},
		{
			name:       "group_available_ratio without group_id returns false",
			metricType: "group_available_ratio",
			groupID:    nil,
			wantValue:  0,
			wantOK:     false,
		},
	}

	for _, tt := range tests {
		tt := tt // capture for parallel subtests (pre-Go 1.22 loop semantics)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			rule := &OpsAlertRule{
				MetricType: tt.metricType,
			}
			// systemMetrics is nil: these metric types never read the snapshot.
			gotValue, gotOK := svc.computeRuleMetric(ctx, rule, nil, start, end, platform, tt.groupID)
			require.Equal(t, tt.wantOK, gotOK)
			if !tt.wantOK {
				return
			}
			require.InDelta(t, tt.wantValue, gotValue, 0.0001)
		})
	}
}
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user