merge: 合并上游 v0.1.86 到 main 分支

This commit is contained in:
erio
2026-02-25 19:02:10 +08:00
469 changed files with 65006 additions and 3674 deletions

View File

@@ -46,4 +46,4 @@ jobs:
with:
version: v2.7
args: --timeout=5m
working-directory: backend
working-directory: backend

4
.gitignore vendored
View File

@@ -122,7 +122,6 @@ AGENTS.md
scripts
.code-review-state
openspec/
docs/
code-reviews/
AGENTS.md
backend/cmd/server/server
@@ -139,3 +138,6 @@ tools/loadtest/
# Antigravity Manager
Antigravity-Manager/
antigravity_projectid_fix.patch
.codex/
frontend/coverage/
aicodex

View File

@@ -209,7 +209,30 @@ git add ent/ # 生成的文件也要提交
---
### 坑 10PR 提交前检查清单
### 坑 10前端测试看似正常,但后端调用失败(模型映射被批量误改)
**典型现象**
- 前端按钮点测看起来正常;
- 实际通过 API/客户端调用时返回 `Service temporarily unavailable` 或提示无可用账号;
- 常见于 OpenAI 账号(例如 Codex 模型)在批量修改后突然不可用。
**根因**
- OpenAI 账号编辑页默认不显式展示映射规则,容易让人误以为“没映射也没关系”;
- 但在**批量修改同时选中不同平台账号**OpenAI + Antigravity/Gemini模型白名单/映射可能被跨平台策略覆盖;
- 结果是 OpenAI 账号的关键模型映射丢失或被改坏,后端选不到可用账号。
**修复方案(按优先级)**
1. **快速修复(推荐)**:在批量修改中补回正确的透传映射(例如 `gpt-5.3-codex -> gpt-5.3-codex-spark`)。
2. **彻底重建**:删除并重新添加全部相关账号(最稳但成本高)。
**关键经验**
- 如果某模型已被软件内置默认映射覆盖,通常不需要额外再加透传;
- 但当上游模型更新快于本仓库默认映射时,**手动批量添加透传映射**是最简单、最低风险的临时兜底方案;
- 批量操作前尽量按平台分组,不要混选不同平台账号。
---
### 坑 11PR 提交前检查清单
提交 PR 前务必本地验证:

View File

@@ -36,7 +36,7 @@ RUN pnpm run build
FROM ${GOLANG_IMAGE} AS backend-builder
# Build arguments for version info (set by CI)
ARG VERSION=docker
ARG VERSION=
ARG COMMIT=docker
ARG DATE
ARG GOPROXY
@@ -61,9 +61,13 @@ COPY backend/ ./
COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
# Build the binary (BuildType=release for CI builds, embed frontend)
RUN CGO_ENABLED=0 GOOS=linux go build \
# Version precedence: build arg VERSION > cmd/server/VERSION
RUN VERSION_VALUE="${VERSION}" && \
if [ -z "${VERSION_VALUE}" ]; then VERSION_VALUE="$(tr -d '\r\n' < ./cmd/server/VERSION)"; fi && \
DATE_VALUE="${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)}" && \
CGO_ENABLED=0 GOOS=linux go build \
-tags embed \
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
-ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \
-o /app/sub2api \
./cmd/server

View File

@@ -1,4 +1,4 @@
.PHONY: build build-backend build-frontend test test-backend test-frontend
.PHONY: build build-backend build-frontend test test-backend test-frontend secret-scan
# 一键编译前后端
build: build-backend build-frontend
@@ -20,3 +20,6 @@ test-backend:
test-frontend:
@pnpm --dir frontend run lint:check
@pnpm --dir frontend run typecheck
secret-scan:
@python3 tools/secret_scan.py

View File

@@ -363,6 +363,12 @@ default:
rate_multiplier: 1.0
```
### Sora Status (Temporarily Unavailable)
> ⚠️ Sora-related features are temporarily unavailable due to technical issues in upstream integration and media delivery.
> Please do not rely on Sora in production at this time.
> Existing `gateway.sora_*` configuration keys are reserved and may not take effect until these issues are resolved.
Additional security-related options are available in `config.yaml`:
- `cors.allowed_origins` for CORS allowlist

View File

@@ -139,6 +139,8 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。
如果你的服务器是 **Ubuntu 24.04**,建议直接参考:`deploy/ubuntu24-docker-compose-aicodex.md`,其中包含「安装最新版 Docker + docker-compose-aicodex.yml 部署」的完整步骤。
#### 前置条件
- Docker 20.10+
@@ -370,6 +372,33 @@ default:
rate_multiplier: 1.0
```
### Sora 功能状态(暂不可用)
> ⚠️ 当前 Sora 相关功能因上游接入与媒体链路存在技术问题,暂时不可用。
> 现阶段请勿在生产环境依赖 Sora 能力。
> 文档中的 `gateway.sora_*` 配置仅作预留,待技术问题修复后再恢复可用。
### Sora 媒体签名 URL功能恢复后可选
当配置 `gateway.sora_media_signing_key``gateway.sora_media_signed_url_ttl_seconds > 0` 时,网关会将 Sora 输出的媒体地址改写为临时签名 URL`/sora/media-signed/...`)。这样无需 API Key 即可在浏览器中直接访问,且具备过期控制与防篡改能力(签名包含 path + query
```yaml
gateway:
# /sora/media 是否强制要求 API Key默认 false
sora_media_require_api_key: false
# 媒体临时签名密钥(为空则禁用签名)
sora_media_signing_key: "your-signing-key"
# 临时签名 URL 有效期(秒)
sora_media_signed_url_ttl_seconds: 900
```
> 若未配置签名密钥,`/sora/media-signed` 将返回 503。
> 如需更严格的访问控制,可将 `sora_media_require_api_key` 设为 true仅允许携带 API Key 的 `/sora/media` 访问。
访问策略说明:
- `/sora/media`:内部调用或客户端携带 API Key 才能下载
- `/sora/media-signed`:外部可访问,但有签名 + 过期控制
`config.yaml` 还支持以下安全相关配置:
- `cors.allowed_origins` 配置 CORS 白名单
@@ -383,6 +412,14 @@ default:
- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
- `turnstile.required` 在 release 模式强制启用 Turnstile
**网关防御纵深建议(重点)**
- `gateway.upstream_response_read_max_bytes`:限制非流式上游响应读取大小(默认 `8MB`),用于防止异常响应导致内存放大。
- `gateway.proxy_probe_response_read_max_bytes`:限制代理探测响应读取大小(默认 `1MB`)。
- `gateway.gemini_debug_response_headers`:默认 `false`,仅在排障时短时开启,避免高频请求日志开销。
- `/auth/register``/auth/login``/auth/login/2fa``/auth/send-verify-code` 已提供服务端兜底限流Redis 故障时 fail-close
- 推荐将 WAF/CDN 作为第一层防护,服务端限流与响应读取上限作为第二层兜底;两层同时保留,避免旁路流量与误配置风险。
**⚠️ 安全警告HTTP URL 配置**
`security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL例如用于开发或内网测试必须显式设置
@@ -428,6 +465,29 @@ Invalid base URL: invalid url scheme: http
./sub2api
```
#### HTTP/2 (h2c) 与 HTTP/1.1 回退
后端明文端口默认支持 h2c并保留 HTTP/1.1 回退用于 WebSocket 与旧客户端。浏览器通常不支持 h2c性能收益主要在反向代理或内网链路。
**反向代理示例Caddy**
```caddyfile
transport http {
versions h2c h1
}
```
**验证:**
```bash
# h2c prior knowledge
curl --http2-prior-knowledge -I http://localhost:8080/health
# HTTP/1.1 回退
curl --http1.1 -I http://localhost:8080/health
# WebSocket 回退验证(需管理员 token
websocat -H="Sec-WebSocket-Protocol: sub2api-admin, jwt.<ADMIN_TOKEN>" ws://localhost:8080/api/v1/admin/ops/ws/qps
```
#### 开发模式
```bash

View File

@@ -14,4 +14,7 @@ test-integration:
go test -tags=integration ./...
test-e2e:
go test -tags=e2e ./...
./scripts/e2e-test.sh
test-e2e-local:
go test -tags=e2e -v -timeout=300s ./internal/integration/...

View File

@@ -17,7 +17,7 @@ func main() {
email := flag.String("email", "", "Admin email to issue a JWT for (defaults to first active admin)")
flag.Parse()
cfg, err := config.Load()
cfg, err := config.LoadForBootstrap()
if err != nil {
log.Fatalf("failed to load config: %v", err)
}

View File

@@ -1 +1 @@
0.1.84.8
0.1.86.1

View File

@@ -8,7 +8,6 @@ import (
"errors"
"flag"
"log"
"log/slog"
"net/http"
"os"
"os/signal"
@@ -19,11 +18,14 @@ import (
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/setup"
"github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
//go:embed VERSION
@@ -38,7 +40,12 @@ var (
)
func init() {
// Read version from embedded VERSION file
// 如果 Version 已通过 ldflags 注入(例如 -X main.Version=...),则不要覆盖。
if strings.TrimSpace(Version) != "" {
return
}
// 默认从 embedded VERSION 文件读取版本号(编译期打包进二进制)。
Version = strings.TrimSpace(embeddedVersion)
if Version == "" {
Version = "0.0.0-dev"
@@ -47,22 +54,9 @@ func init() {
// initLogger configures the default slog handler based on gin.Mode().
// In non-release mode, Debug level logs are enabled.
func initLogger() {
var level slog.Level
if gin.Mode() == gin.ReleaseMode {
level = slog.LevelInfo
} else {
level = slog.LevelDebug
}
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
slog.SetDefault(slog.New(handler))
}
func main() {
// Initialize slog logger based on gin mode
initLogger()
logger.InitBootstrap()
defer logger.Sync()
// Parse command line flags
setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode")
@@ -122,16 +116,26 @@ func runSetupServer() {
log.Printf("Setup wizard available at http://%s", addr)
log.Println("Complete the setup wizard to configure Sub2API")
if err := r.Run(addr); err != nil {
server := &http.Server{
Addr: addr,
Handler: h2c.NewHandler(r, &http2.Server{}),
ReadHeaderTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
}
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start setup server: %v", err)
}
}
func runMainServer() {
cfg, err := config.Load()
cfg, err := config.LoadForBootstrap()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
if err := logger.Init(logger.OptionsFromConfig(cfg.Log)); err != nil {
log.Fatalf("Failed to initialize logger: %v", err)
}
if cfg.RunMode == config.RunModeSimple {
log.Println("⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED")
}

View File

@@ -67,14 +67,19 @@ func provideCleanup(
opsAlertEvaluator *service.OpsAlertEvaluatorService,
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
opsSystemLogSink *service.OpsSystemLogSink,
soraMediaCleanup *service.SoraMediaCleanupService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
idempotencyCleanup *service.IdempotencyCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
subscriptionService *service.SubscriptionService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
@@ -101,6 +106,18 @@ func provideCleanup(
}
return nil
}},
{"OpsSystemLogSink", func() error {
if opsSystemLogSink != nil {
opsSystemLogSink.Stop()
}
return nil
}},
{"SoraMediaCleanupService", func() error {
if soraMediaCleanup != nil {
soraMediaCleanup.Stop()
}
return nil
}},
{"OpsAlertEvaluatorService", func() error {
if opsAlertEvaluator != nil {
opsAlertEvaluator.Stop()
@@ -131,6 +148,12 @@ func provideCleanup(
}
return nil
}},
{"IdempotencyCleanupService", func() error {
if idempotencyCleanup != nil {
idempotencyCleanup.Stop()
}
return nil
}},
{"TokenRefreshService", func() error {
tokenRefresh.Stop()
return nil
@@ -143,6 +166,12 @@ func provideCleanup(
subscriptionExpiry.Stop()
return nil
}},
{"SubscriptionService", func() error {
if subscriptionService != nil {
subscriptionService.Stop()
}
return nil
}},
{"PricingService", func() error {
pricing.Stop()
return nil
@@ -155,6 +184,12 @@ func provideCleanup(
billingCache.Stop()
return nil
}},
{"UsageRecordWorkerPool", func() error {
if usageRecordWorkerPool != nil {
usageRecordWorkerPool.Stop()
}
return nil
}},
{"OAuthService", func() error {
oauth.Stop()
return nil

View File

@@ -65,8 +65,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
@@ -98,10 +98,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
schedulerCache := repository.NewSchedulerCache(redisClient)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
soraAccountRepository := repository.NewSoraAccountRepository(db)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
@@ -159,14 +160,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService)
idempotencyRepository := repository.NewIdempotencyRepository(client, db)
systemOperationLockService := service.ProvideSystemOperationLockService(idempotencyRepository, configConfig)
systemHandler := handler.ProvideSystemHandler(updateService, systemOperationLockService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
usageCleanupRepository := repository.NewUsageCleanupRepository(client, db)
usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
@@ -180,11 +184,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
totpHandler := handler.NewTotpHandler(totpService)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -195,10 +206,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -228,14 +240,19 @@ func provideCleanup(
opsAlertEvaluator *service.OpsAlertEvaluatorService,
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
opsSystemLogSink *service.OpsSystemLogSink,
soraMediaCleanup *service.SoraMediaCleanupService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
idempotencyCleanup *service.IdempotencyCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
subscriptionService *service.SubscriptionService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
@@ -261,6 +278,18 @@ func provideCleanup(
}
return nil
}},
{"OpsSystemLogSink", func() error {
if opsSystemLogSink != nil {
opsSystemLogSink.Stop()
}
return nil
}},
{"SoraMediaCleanupService", func() error {
if soraMediaCleanup != nil {
soraMediaCleanup.Stop()
}
return nil
}},
{"OpsAlertEvaluatorService", func() error {
if opsAlertEvaluator != nil {
opsAlertEvaluator.Stop()
@@ -291,6 +320,12 @@ func provideCleanup(
}
return nil
}},
{"IdempotencyCleanupService", func() error {
if idempotencyCleanup != nil {
idempotencyCleanup.Stop()
}
return nil
}},
{"TokenRefreshService", func() error {
tokenRefresh.Stop()
return nil
@@ -303,6 +338,12 @@ func provideCleanup(
subscriptionExpiry.Stop()
return nil
}},
{"SubscriptionService", func() error {
if subscriptionService != nil {
subscriptionService.Stop()
}
return nil
}},
{"PricingService", func() error {
pricing.Stop()
return nil
@@ -315,6 +356,12 @@ func provideCleanup(
billingCache.Stop()
return nil
}},
{"UsageRecordWorkerPool", func() error {
if usageRecordWorkerPool != nil {
usageRecordWorkerPool.Stop()
}
return nil
}},
{"OAuthService", func() error {
oauth.Stop()
return nil

View File

@@ -36,6 +36,8 @@ type APIKey struct {
GroupID *int64 `json:"group_id,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// Last usage time of this API key
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
// Allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"]
IPWhitelist []string `json:"ip_whitelist,omitempty"`
// Blocked IPs/CIDRs
@@ -109,7 +111,7 @@ func (*APIKey) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullInt64)
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
values[i] = new(sql.NullString)
case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldExpiresAt:
case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldLastUsedAt, apikey.FieldExpiresAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -182,6 +184,13 @@ func (_m *APIKey) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Status = value.String
}
case apikey.FieldLastUsedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field last_used_at", values[i])
} else if value.Valid {
_m.LastUsedAt = new(time.Time)
*_m.LastUsedAt = value.Time
}
case apikey.FieldIPWhitelist:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field ip_whitelist", values[i])
@@ -296,6 +305,11 @@ func (_m *APIKey) String() string {
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
if v := _m.LastUsedAt; v != nil {
builder.WriteString("last_used_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("ip_whitelist=")
builder.WriteString(fmt.Sprintf("%v", _m.IPWhitelist))
builder.WriteString(", ")

View File

@@ -31,6 +31,8 @@ const (
FieldGroupID = "group_id"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldLastUsedAt holds the string denoting the last_used_at field in the database.
FieldLastUsedAt = "last_used_at"
// FieldIPWhitelist holds the string denoting the ip_whitelist field in the database.
FieldIPWhitelist = "ip_whitelist"
// FieldIPBlacklist holds the string denoting the ip_blacklist field in the database.
@@ -83,6 +85,7 @@ var Columns = []string{
FieldName,
FieldGroupID,
FieldStatus,
FieldLastUsedAt,
FieldIPWhitelist,
FieldIPBlacklist,
FieldQuota,
@@ -176,6 +179,11 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
}
// ByLastUsedAt orders the results by the last_used_at field.
func ByLastUsedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldLastUsedAt, opts...).ToFunc()
}
// ByQuota orders the results by the quota field.
func ByQuota(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldQuota, opts...).ToFunc()

View File

@@ -95,6 +95,11 @@ func Status(v string) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldStatus, v))
}
// LastUsedAt applies equality check predicate on the "last_used_at" field. It's identical to LastUsedAtEQ.
func LastUsedAt(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldLastUsedAt, v))
}
// Quota applies equality check predicate on the "quota" field. It's identical to QuotaEQ.
func Quota(v float64) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldQuota, v))
@@ -485,6 +490,56 @@ func StatusContainsFold(v string) predicate.APIKey {
return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v))
}
// LastUsedAtEQ applies the EQ predicate on the "last_used_at" field.
func LastUsedAtEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldEQ(FieldLastUsedAt, v))
}
// LastUsedAtNEQ applies the NEQ predicate on the "last_used_at" field.
func LastUsedAtNEQ(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldNEQ(FieldLastUsedAt, v))
}
// LastUsedAtIn applies the In predicate on the "last_used_at" field.
func LastUsedAtIn(vs ...time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldIn(FieldLastUsedAt, vs...))
}
// LastUsedAtNotIn applies the NotIn predicate on the "last_used_at" field.
func LastUsedAtNotIn(vs ...time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldNotIn(FieldLastUsedAt, vs...))
}
// LastUsedAtGT applies the GT predicate on the "last_used_at" field.
func LastUsedAtGT(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldGT(FieldLastUsedAt, v))
}
// LastUsedAtGTE applies the GTE predicate on the "last_used_at" field.
func LastUsedAtGTE(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldGTE(FieldLastUsedAt, v))
}
// LastUsedAtLT applies the LT predicate on the "last_used_at" field.
func LastUsedAtLT(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldLT(FieldLastUsedAt, v))
}
// LastUsedAtLTE applies the LTE predicate on the "last_used_at" field.
func LastUsedAtLTE(v time.Time) predicate.APIKey {
return predicate.APIKey(sql.FieldLTE(FieldLastUsedAt, v))
}
// LastUsedAtIsNil applies the IsNil predicate on the "last_used_at" field.
func LastUsedAtIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldLastUsedAt))
}
// LastUsedAtNotNil applies the NotNil predicate on the "last_used_at" field.
func LastUsedAtNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldLastUsedAt))
}
// IPWhitelistIsNil applies the IsNil predicate on the "ip_whitelist" field.
func IPWhitelistIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldIPWhitelist))

View File

@@ -113,6 +113,20 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate {
return _c
}
// SetLastUsedAt sets the "last_used_at" field.
func (_c *APIKeyCreate) SetLastUsedAt(v time.Time) *APIKeyCreate {
_c.mutation.SetLastUsedAt(v)
return _c
}
// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil.
func (_c *APIKeyCreate) SetNillableLastUsedAt(v *time.Time) *APIKeyCreate {
if v != nil {
_c.SetLastUsedAt(*v)
}
return _c
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (_c *APIKeyCreate) SetIPWhitelist(v []string) *APIKeyCreate {
_c.mutation.SetIPWhitelist(v)
@@ -353,6 +367,10 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) {
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
_node.Status = value
}
if value, ok := _c.mutation.LastUsedAt(); ok {
_spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value)
_node.LastUsedAt = &value
}
if value, ok := _c.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
_node.IPWhitelist = value
@@ -571,6 +589,24 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert {
return u
}
// SetLastUsedAt sets the "last_used_at" field.
func (u *APIKeyUpsert) SetLastUsedAt(v time.Time) *APIKeyUpsert {
u.Set(apikey.FieldLastUsedAt, v)
return u
}
// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateLastUsedAt() *APIKeyUpsert {
u.SetExcluded(apikey.FieldLastUsedAt)
return u
}
// ClearLastUsedAt clears the value of the "last_used_at" field.
func (u *APIKeyUpsert) ClearLastUsedAt() *APIKeyUpsert {
u.SetNull(apikey.FieldLastUsedAt)
return u
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsert) SetIPWhitelist(v []string) *APIKeyUpsert {
u.Set(apikey.FieldIPWhitelist, v)
@@ -818,6 +854,27 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne {
})
}
// SetLastUsedAt sets the "last_used_at" field.
func (u *APIKeyUpsertOne) SetLastUsedAt(v time.Time) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetLastUsedAt(v)
})
}
// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateLastUsedAt() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateLastUsedAt()
})
}
// ClearLastUsedAt clears the value of the "last_used_at" field.
func (u *APIKeyUpsertOne) ClearLastUsedAt() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearLastUsedAt()
})
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsertOne) SetIPWhitelist(v []string) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
@@ -1246,6 +1303,27 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk {
})
}
// SetLastUsedAt sets the "last_used_at" field.
func (u *APIKeyUpsertBulk) SetLastUsedAt(v time.Time) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetLastUsedAt(v)
})
}
// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateLastUsedAt() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateLastUsedAt()
})
}
// ClearLastUsedAt clears the value of the "last_used_at" field.
func (u *APIKeyUpsertBulk) ClearLastUsedAt() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearLastUsedAt()
})
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsertBulk) SetIPWhitelist(v []string) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {

View File

@@ -134,6 +134,26 @@ func (_u *APIKeyUpdate) SetNillableStatus(v *string) *APIKeyUpdate {
return _u
}
// SetLastUsedAt sets the "last_used_at" field.
func (_u *APIKeyUpdate) SetLastUsedAt(v time.Time) *APIKeyUpdate {
_u.mutation.SetLastUsedAt(v)
return _u
}
// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil.
func (_u *APIKeyUpdate) SetNillableLastUsedAt(v *time.Time) *APIKeyUpdate {
if v != nil {
_u.SetLastUsedAt(*v)
}
return _u
}
// ClearLastUsedAt clears the value of the "last_used_at" field.
func (_u *APIKeyUpdate) ClearLastUsedAt() *APIKeyUpdate {
_u.mutation.ClearLastUsedAt()
return _u
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (_u *APIKeyUpdate) SetIPWhitelist(v []string) *APIKeyUpdate {
_u.mutation.SetIPWhitelist(v)
@@ -390,6 +410,12 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.LastUsedAt(); ok {
_spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value)
}
if _u.mutation.LastUsedAtCleared() {
_spec.ClearField(apikey.FieldLastUsedAt, field.TypeTime)
}
if value, ok := _u.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
}
@@ -655,6 +681,26 @@ func (_u *APIKeyUpdateOne) SetNillableStatus(v *string) *APIKeyUpdateOne {
return _u
}
// SetLastUsedAt sets the "last_used_at" field.
func (_u *APIKeyUpdateOne) SetLastUsedAt(v time.Time) *APIKeyUpdateOne {
_u.mutation.SetLastUsedAt(v)
return _u
}
// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil.
func (_u *APIKeyUpdateOne) SetNillableLastUsedAt(v *time.Time) *APIKeyUpdateOne {
if v != nil {
_u.SetLastUsedAt(*v)
}
return _u
}
// ClearLastUsedAt clears the value of the "last_used_at" field.
func (_u *APIKeyUpdateOne) ClearLastUsedAt() *APIKeyUpdateOne {
_u.mutation.ClearLastUsedAt()
return _u
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (_u *APIKeyUpdateOne) SetIPWhitelist(v []string) *APIKeyUpdateOne {
_u.mutation.SetIPWhitelist(v)
@@ -941,6 +987,12 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.LastUsedAt(); ok {
_spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value)
}
if _u.mutation.LastUsedAtCleared() {
_spec.ClearField(apikey.FieldLastUsedAt, field.TypeTime)
}
if value, ok := _u.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
}

View File

@@ -26,6 +26,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -65,6 +66,8 @@ type Client struct {
Proxy *ProxyClient
// RedeemCode is the client for interacting with the RedeemCode builders.
RedeemCode *RedeemCodeClient
// SecuritySecret is the client for interacting with the SecuritySecret builders.
SecuritySecret *SecuritySecretClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
@@ -103,6 +106,7 @@ func (c *Client) init() {
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
c.Proxy = NewProxyClient(c.config)
c.RedeemCode = NewRedeemCodeClient(c.config)
c.SecuritySecret = NewSecuritySecretClient(c.config)
c.Setting = NewSettingClient(c.config)
c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config)
c.UsageLog = NewUsageLogClient(c.config)
@@ -214,6 +218,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
SecuritySecret: NewSecuritySecretClient(cfg),
Setting: NewSettingClient(cfg),
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
@@ -252,6 +257,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
SecuritySecret: NewSecuritySecretClient(cfg),
Setting: NewSettingClient(cfg),
UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
@@ -291,8 +297,8 @@ func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy,
c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog,
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Use(hooks...)
@@ -305,8 +311,8 @@ func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy,
c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog,
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Intercept(interceptors...)
@@ -338,6 +344,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.Proxy.mutate(ctx, m)
case *RedeemCodeMutation:
return c.RedeemCode.mutate(ctx, m)
case *SecuritySecretMutation:
return c.SecuritySecret.mutate(ctx, m)
case *SettingMutation:
return c.Setting.mutate(ctx, m)
case *UsageCleanupTaskMutation:
@@ -2197,6 +2205,139 @@ func (c *RedeemCodeClient) mutate(ctx context.Context, m *RedeemCodeMutation) (V
}
}
// SecuritySecretClient is a client for the SecuritySecret schema.
type SecuritySecretClient struct {
config
}
// NewSecuritySecretClient returns a client for the SecuritySecret from the given config.
func NewSecuritySecretClient(c config) *SecuritySecretClient {
return &SecuritySecretClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `securitysecret.Hooks(f(g(h())))`.
func (c *SecuritySecretClient) Use(hooks ...Hook) {
c.hooks.SecuritySecret = append(c.hooks.SecuritySecret, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `securitysecret.Intercept(f(g(h())))`.
func (c *SecuritySecretClient) Intercept(interceptors ...Interceptor) {
c.inters.SecuritySecret = append(c.inters.SecuritySecret, interceptors...)
}
// Create returns a builder for creating a SecuritySecret entity.
func (c *SecuritySecretClient) Create() *SecuritySecretCreate {
mutation := newSecuritySecretMutation(c.config, OpCreate)
return &SecuritySecretCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of SecuritySecret entities.
func (c *SecuritySecretClient) CreateBulk(builders ...*SecuritySecretCreate) *SecuritySecretCreateBulk {
return &SecuritySecretCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *SecuritySecretClient) MapCreateBulk(slice any, setFunc func(*SecuritySecretCreate, int)) *SecuritySecretCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &SecuritySecretCreateBulk{err: fmt.Errorf("calling to SecuritySecretClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*SecuritySecretCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &SecuritySecretCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for SecuritySecret.
func (c *SecuritySecretClient) Update() *SecuritySecretUpdate {
mutation := newSecuritySecretMutation(c.config, OpUpdate)
return &SecuritySecretUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *SecuritySecretClient) UpdateOne(_m *SecuritySecret) *SecuritySecretUpdateOne {
mutation := newSecuritySecretMutation(c.config, OpUpdateOne, withSecuritySecret(_m))
return &SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *SecuritySecretClient) UpdateOneID(id int64) *SecuritySecretUpdateOne {
mutation := newSecuritySecretMutation(c.config, OpUpdateOne, withSecuritySecretID(id))
return &SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for SecuritySecret.
func (c *SecuritySecretClient) Delete() *SecuritySecretDelete {
mutation := newSecuritySecretMutation(c.config, OpDelete)
return &SecuritySecretDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *SecuritySecretClient) DeleteOne(_m *SecuritySecret) *SecuritySecretDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *SecuritySecretClient) DeleteOneID(id int64) *SecuritySecretDeleteOne {
builder := c.Delete().Where(securitysecret.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &SecuritySecretDeleteOne{builder}
}
// Query returns a query builder for SecuritySecret.
func (c *SecuritySecretClient) Query() *SecuritySecretQuery {
return &SecuritySecretQuery{
config: c.config,
ctx: &QueryContext{Type: TypeSecuritySecret},
inters: c.Interceptors(),
}
}
// Get returns a SecuritySecret entity by its id.
func (c *SecuritySecretClient) Get(ctx context.Context, id int64) (*SecuritySecret, error) {
return c.Query().Where(securitysecret.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *SecuritySecretClient) GetX(ctx context.Context, id int64) *SecuritySecret {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// Hooks returns the client hooks.
func (c *SecuritySecretClient) Hooks() []Hook {
return c.hooks.SecuritySecret
}
// Interceptors returns the client interceptors.
func (c *SecuritySecretClient) Interceptors() []Interceptor {
return c.inters.SecuritySecret
}
func (c *SecuritySecretClient) mutate(ctx context.Context, m *SecuritySecretMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&SecuritySecretCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&SecuritySecretUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&SecuritySecretDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown SecuritySecret mutation op: %q", m.Op())
}
}
// SettingClient is a client for the Setting schema.
type SettingClient struct {
config
@@ -3607,13 +3748,13 @@ type (
hooks struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
}
)

View File

@@ -23,6 +23,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -102,6 +103,7 @@ func checkColumn(t, c string) error {
promocodeusage.Table: promocodeusage.ValidColumn,
proxy.Table: proxy.ValidColumn,
redeemcode.Table: redeemcode.ValidColumn,
securitysecret.Table: securitysecret.ValidColumn,
setting.Table: setting.ValidColumn,
usagecleanuptask.Table: usagecleanuptask.ValidColumn,
usagelog.Table: usagelog.ValidColumn,

View File

@@ -52,6 +52,14 @@ type Group struct {
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
// ImagePrice4k holds the value of the "image_price_4k" field.
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
// SoraImagePrice360 holds the value of the "sora_image_price_360" field.
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
// SoraImagePrice540 holds the value of the "sora_image_price_540" field.
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
// SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field.
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
// SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field.
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
// 是否仅允许 Claude Code 客户端
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID
@@ -178,7 +186,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject:
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
values[i] = new(sql.NullFloat64)
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
values[i] = new(sql.NullInt64)
@@ -317,6 +325,34 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.ImagePrice4k = new(float64)
*_m.ImagePrice4k = value.Float64
}
case group.FieldSoraImagePrice360:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i])
} else if value.Valid {
_m.SoraImagePrice360 = new(float64)
*_m.SoraImagePrice360 = value.Float64
}
case group.FieldSoraImagePrice540:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i])
} else if value.Valid {
_m.SoraImagePrice540 = new(float64)
*_m.SoraImagePrice540 = value.Float64
}
case group.FieldSoraVideoPricePerRequest:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i])
} else if value.Valid {
_m.SoraVideoPricePerRequest = new(float64)
*_m.SoraVideoPricePerRequest = value.Float64
}
case group.FieldSoraVideoPricePerRequestHd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i])
} else if value.Valid {
_m.SoraVideoPricePerRequestHd = new(float64)
*_m.SoraVideoPricePerRequestHd = value.Float64
}
case group.FieldClaudeCodeOnly:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
@@ -514,6 +550,26 @@ func (_m *Group) String() string {
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.SoraImagePrice360; v != nil {
builder.WriteString("sora_image_price_360=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.SoraImagePrice540; v != nil {
builder.WriteString("sora_image_price_540=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.SoraVideoPricePerRequest; v != nil {
builder.WriteString("sora_video_price_per_request=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.SoraVideoPricePerRequestHd; v != nil {
builder.WriteString("sora_video_price_per_request_hd=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("claude_code_only=")
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
builder.WriteString(", ")

View File

@@ -49,6 +49,14 @@ const (
FieldImagePrice2k = "image_price_2k"
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
FieldImagePrice4k = "image_price_4k"
// FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database.
FieldSoraImagePrice360 = "sora_image_price_360"
// FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database.
FieldSoraImagePrice540 = "sora_image_price_540"
// FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database.
FieldSoraVideoPricePerRequest = "sora_video_price_per_request"
// FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database.
FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd"
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
@@ -157,6 +165,10 @@ var Columns = []string{
FieldImagePrice1k,
FieldImagePrice2k,
FieldImagePrice4k,
FieldSoraImagePrice360,
FieldSoraImagePrice540,
FieldSoraVideoPricePerRequest,
FieldSoraVideoPricePerRequestHd,
FieldClaudeCodeOnly,
FieldFallbackGroupID,
FieldFallbackGroupIDOnInvalidRequest,
@@ -325,6 +337,26 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
}
// BySoraImagePrice360 orders the results by the sora_image_price_360 field.
func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc()
}
// BySoraImagePrice540 orders the results by the sora_image_price_540 field.
func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc()
}
// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field.
func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc()
}
// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field.
func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc()
}
// ByClaudeCodeOnly orders the results by the claude_code_only field.
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()

View File

@@ -140,6 +140,26 @@ func ImagePrice4k(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
}
// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ.
func SoraImagePrice360(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
}
// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ.
func SoraImagePrice540(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
}
// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ.
func SoraVideoPricePerRequest(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ.
func SoraVideoPricePerRequestHd(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
}
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
func ClaudeCodeOnly(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
@@ -1025,6 +1045,206 @@ func ImagePrice4kNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
}
// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field.
func SoraImagePrice360EQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
}
// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field.
func SoraImagePrice360NEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v))
}
// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field.
func SoraImagePrice360In(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...))
}
// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field.
func SoraImagePrice360NotIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...))
}
// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field.
func SoraImagePrice360GT(v float64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v))
}
// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field.
func SoraImagePrice360GTE(v float64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v))
}
// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field.
func SoraImagePrice360LT(v float64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v))
}
// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field.
func SoraImagePrice360LTE(v float64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v))
}
// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field.
func SoraImagePrice360IsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360))
}
// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field.
func SoraImagePrice360NotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360))
}
// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field.
func SoraImagePrice540EQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
}
// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field.
func SoraImagePrice540NEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v))
}
// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field.
func SoraImagePrice540In(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...))
}
// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field.
func SoraImagePrice540NotIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...))
}
// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field.
func SoraImagePrice540GT(v float64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v))
}
// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field.
func SoraImagePrice540GTE(v float64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v))
}
// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field.
func SoraImagePrice540LT(v float64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v))
}
// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field.
func SoraImagePrice540LTE(v float64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v))
}
// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field.
func SoraImagePrice540IsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540))
}
// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field.
func SoraImagePrice540NotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540))
}
// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestNEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...))
}
// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...))
}
// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestGT(v float64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestGTE(v float64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestLT(v float64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestLTE(v float64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v))
}
// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest))
}
// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field.
func SoraVideoPricePerRequestNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest))
}
// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...))
}
// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group {
return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...))
}
// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdGT(v float64) predicate.Group {
return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group {
return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdLT(v float64) predicate.Group {
return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v))
}
// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdIsNil() predicate.Group {
return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd))
}
// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field.
func SoraVideoPricePerRequestHdNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd))
}
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))

View File

@@ -258,6 +258,62 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate {
return _c
}
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate {
_c.mutation.SetSoraImagePrice360(v)
return _c
}
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate {
if v != nil {
_c.SetSoraImagePrice360(*v)
}
return _c
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate {
_c.mutation.SetSoraImagePrice540(v)
return _c
}
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate {
if v != nil {
_c.SetSoraImagePrice540(*v)
}
return _c
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate {
_c.mutation.SetSoraVideoPricePerRequest(v)
return _c
}
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate {
if v != nil {
_c.SetSoraVideoPricePerRequest(*v)
}
return _c
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate {
_c.mutation.SetSoraVideoPricePerRequestHd(v)
return _c
}
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate {
if v != nil {
_c.SetSoraVideoPricePerRequestHd(*v)
}
return _c
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
_c.mutation.SetClaudeCodeOnly(v)
@@ -701,6 +757,22 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
_node.ImagePrice4k = &value
}
if value, ok := _c.mutation.SoraImagePrice360(); ok {
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
_node.SoraImagePrice360 = &value
}
if value, ok := _c.mutation.SoraImagePrice540(); ok {
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
_node.SoraImagePrice540 = &value
}
if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
_node.SoraVideoPricePerRequest = &value
}
if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
_node.SoraVideoPricePerRequestHd = &value
}
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
_node.ClaudeCodeOnly = value
@@ -1177,6 +1249,102 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert {
return u
}
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert {
u.Set(group.FieldSoraImagePrice360, v)
return u
}
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert {
u.SetExcluded(group.FieldSoraImagePrice360)
return u
}
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert {
u.Add(group.FieldSoraImagePrice360, v)
return u
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert {
u.SetNull(group.FieldSoraImagePrice360)
return u
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert {
u.Set(group.FieldSoraImagePrice540, v)
return u
}
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert {
u.SetExcluded(group.FieldSoraImagePrice540)
return u
}
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert {
u.Add(group.FieldSoraImagePrice540, v)
return u
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert {
u.SetNull(group.FieldSoraImagePrice540)
return u
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert {
u.Set(group.FieldSoraVideoPricePerRequest, v)
return u
}
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert {
u.SetExcluded(group.FieldSoraVideoPricePerRequest)
return u
}
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert {
u.Add(group.FieldSoraVideoPricePerRequest, v)
return u
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert {
u.SetNull(group.FieldSoraVideoPricePerRequest)
return u
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
u.Set(group.FieldSoraVideoPricePerRequestHd, v)
return u
}
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert {
u.SetExcluded(group.FieldSoraVideoPricePerRequestHd)
return u
}
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
u.Add(group.FieldSoraVideoPricePerRequestHd, v)
return u
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert {
u.SetNull(group.FieldSoraVideoPricePerRequestHd)
return u
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
u.Set(group.FieldClaudeCodeOnly, v)
@@ -1690,6 +1858,118 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne {
})
}
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSoraImagePrice360(v)
})
}
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddSoraImagePrice360(v)
})
}
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraImagePrice360()
})
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraImagePrice360()
})
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSoraImagePrice540(v)
})
}
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddSoraImagePrice540(v)
})
}
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraImagePrice540()
})
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraImagePrice540()
})
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSoraVideoPricePerRequest(v)
})
}
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddSoraVideoPricePerRequest(v)
})
}
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraVideoPricePerRequest()
})
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraVideoPricePerRequest()
})
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSoraVideoPricePerRequestHd(v)
})
}
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.AddSoraVideoPricePerRequestHd(v)
})
}
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraVideoPricePerRequestHd()
})
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraVideoPricePerRequestHd()
})
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
@@ -2391,6 +2671,118 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk {
})
}
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSoraImagePrice360(v)
})
}
// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddSoraImagePrice360(v)
})
}
// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraImagePrice360()
})
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraImagePrice360()
})
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSoraImagePrice540(v)
})
}
// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddSoraImagePrice540(v)
})
}
// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraImagePrice540()
})
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraImagePrice540()
})
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSoraVideoPricePerRequest(v)
})
}
// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddSoraVideoPricePerRequest(v)
})
}
// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraVideoPricePerRequest()
})
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraVideoPricePerRequest()
})
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSoraVideoPricePerRequestHd(v)
})
}
// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.AddSoraVideoPricePerRequestHd(v)
})
}
// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSoraVideoPricePerRequestHd()
})
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.ClearSoraVideoPricePerRequestHd()
})
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {

View File

@@ -355,6 +355,114 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate {
return _u
}
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate {
_u.mutation.ResetSoraImagePrice360()
_u.mutation.SetSoraImagePrice360(v)
return _u
}
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate {
if v != nil {
_u.SetSoraImagePrice360(*v)
}
return _u
}
// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate {
_u.mutation.AddSoraImagePrice360(v)
return _u
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate {
_u.mutation.ClearSoraImagePrice360()
return _u
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate {
_u.mutation.ResetSoraImagePrice540()
_u.mutation.SetSoraImagePrice540(v)
return _u
}
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate {
if v != nil {
_u.SetSoraImagePrice540(*v)
}
return _u
}
// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate {
_u.mutation.AddSoraImagePrice540(v)
return _u
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate {
_u.mutation.ClearSoraImagePrice540()
return _u
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate {
_u.mutation.ResetSoraVideoPricePerRequest()
_u.mutation.SetSoraVideoPricePerRequest(v)
return _u
}
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate {
if v != nil {
_u.SetSoraVideoPricePerRequest(*v)
}
return _u
}
// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate {
_u.mutation.AddSoraVideoPricePerRequest(v)
return _u
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate {
_u.mutation.ClearSoraVideoPricePerRequest()
return _u
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
_u.mutation.ResetSoraVideoPricePerRequestHd()
_u.mutation.SetSoraVideoPricePerRequestHd(v)
return _u
}
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate {
if v != nil {
_u.SetSoraVideoPricePerRequestHd(*v)
}
return _u
}
// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
_u.mutation.AddSoraVideoPricePerRequestHd(v)
return _u
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate {
_u.mutation.ClearSoraVideoPricePerRequestHd()
return _u
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
_u.mutation.SetClaudeCodeOnly(v)
@@ -892,6 +1000,42 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraImagePrice360(); ok {
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
_spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
}
if _u.mutation.SoraImagePrice360Cleared() {
_spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraImagePrice540(); ok {
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
_spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
}
if _u.mutation.SoraImagePrice540Cleared() {
_spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
_spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
}
if _u.mutation.SoraVideoPricePerRequestCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
_spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
}
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
}
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
@@ -1573,6 +1717,114 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne {
return _u
}
// SetSoraImagePrice360 sets the "sora_image_price_360" field.
func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne {
_u.mutation.ResetSoraImagePrice360()
_u.mutation.SetSoraImagePrice360(v)
return _u
}
// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne {
if v != nil {
_u.SetSoraImagePrice360(*v)
}
return _u
}
// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne {
_u.mutation.AddSoraImagePrice360(v)
return _u
}
// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne {
_u.mutation.ClearSoraImagePrice360()
return _u
}
// SetSoraImagePrice540 sets the "sora_image_price_540" field.
func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne {
_u.mutation.ResetSoraImagePrice540()
_u.mutation.SetSoraImagePrice540(v)
return _u
}
// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne {
if v != nil {
_u.SetSoraImagePrice540(*v)
}
return _u
}
// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne {
_u.mutation.AddSoraImagePrice540(v)
return _u
}
// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne {
_u.mutation.ClearSoraImagePrice540()
return _u
}
// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
_u.mutation.ResetSoraVideoPricePerRequest()
_u.mutation.SetSoraVideoPricePerRequest(v)
return _u
}
// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne {
if v != nil {
_u.SetSoraVideoPricePerRequest(*v)
}
return _u
}
// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
_u.mutation.AddSoraVideoPricePerRequest(v)
return _u
}
// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne {
_u.mutation.ClearSoraVideoPricePerRequest()
return _u
}
// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
_u.mutation.ResetSoraVideoPricePerRequestHd()
_u.mutation.SetSoraVideoPricePerRequestHd(v)
return _u
}
// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne {
if v != nil {
_u.SetSoraVideoPricePerRequestHd(*v)
}
return _u
}
// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
_u.mutation.AddSoraVideoPricePerRequestHd(v)
return _u
}
// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne {
_u.mutation.ClearSoraVideoPricePerRequestHd()
return _u
}
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
_u.mutation.SetClaudeCodeOnly(v)
@@ -2140,6 +2392,42 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraImagePrice360(); ok {
_spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
_spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
}
if _u.mutation.SoraImagePrice360Cleared() {
_spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraImagePrice540(); ok {
_spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
_spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
}
if _u.mutation.SoraImagePrice540Cleared() {
_spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
_spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
}
if _u.mutation.SoraVideoPricePerRequestCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
}
if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
_spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
}
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
}
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}

View File

@@ -141,6 +141,18 @@ func (f RedeemCodeFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value,
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.RedeemCodeMutation", m)
}
// The SecuritySecretFunc type is an adapter to allow the use of ordinary
// function as SecuritySecret mutator.
type SecuritySecretFunc func(context.Context, *ent.SecuritySecretMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f SecuritySecretFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.SecuritySecretMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SecuritySecretMutation", m)
}
// The SettingFunc type is an adapter to allow the use of ordinary
// function as Setting mutator.
type SettingFunc func(context.Context, *ent.SettingMutation) (ent.Value, error)

View File

@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -383,6 +384,33 @@ func (f TraverseRedeemCode) Traverse(ctx context.Context, q ent.Query) error {
return fmt.Errorf("unexpected query type %T. expect *ent.RedeemCodeQuery", q)
}
// The SecuritySecretFunc type is an adapter to allow the use of ordinary function as a Querier.
type SecuritySecretFunc func(context.Context, *ent.SecuritySecretQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f SecuritySecretFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.SecuritySecretQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.SecuritySecretQuery", q)
}
// The TraverseSecuritySecret type is an adapter to allow the use of ordinary function as Traverser.
type TraverseSecuritySecret func(context.Context, *ent.SecuritySecretQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseSecuritySecret) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseSecuritySecret) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.SecuritySecretQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.SecuritySecretQuery", q)
}
// The SettingFunc type is an adapter to allow the use of ordinary function as a Querier.
type SettingFunc func(context.Context, *ent.SettingQuery) (ent.Value, error)
@@ -624,6 +652,8 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.ProxyQuery, predicate.Proxy, proxy.OrderOption]{typ: ent.TypeProxy, tq: q}, nil
case *ent.RedeemCodeQuery:
return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil
case *ent.SecuritySecretQuery:
return &query[*ent.SecuritySecretQuery, predicate.SecuritySecret, securitysecret.OrderOption]{typ: ent.TypeSecuritySecret, tq: q}, nil
case *ent.SettingQuery:
return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil
case *ent.UsageCleanupTaskQuery:

View File

@@ -18,6 +18,7 @@ var (
{Name: "key", Type: field.TypeString, Unique: true, Size: 128},
{Name: "name", Type: field.TypeString, Size: 100},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "last_used_at", Type: field.TypeTime, Nullable: true},
{Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true},
{Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true},
{Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
@@ -34,13 +35,13 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "api_keys_groups_api_keys",
Columns: []*schema.Column{APIKeysColumns[12]},
Columns: []*schema.Column{APIKeysColumns[13]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "api_keys_users_api_keys",
Columns: []*schema.Column{APIKeysColumns[13]},
Columns: []*schema.Column{APIKeysColumns[14]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
@@ -49,12 +50,12 @@ var (
{
Name: "apikey_user_id",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[13]},
Columns: []*schema.Column{APIKeysColumns[14]},
},
{
Name: "apikey_group_id",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[12]},
Columns: []*schema.Column{APIKeysColumns[13]},
},
{
Name: "apikey_status",
@@ -66,15 +67,20 @@ var (
Unique: false,
Columns: []*schema.Column{APIKeysColumns[3]},
},
{
Name: "apikey_last_used_at",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[7]},
},
{
Name: "apikey_quota_quota_used",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[9], APIKeysColumns[10]},
Columns: []*schema.Column{APIKeysColumns[10], APIKeysColumns[11]},
},
{
Name: "apikey_expires_at",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[11]},
Columns: []*schema.Column{APIKeysColumns[12]},
},
},
}
@@ -366,6 +372,10 @@ var (
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
@@ -409,7 +419,7 @@ var (
{
Name: "group_sort_order",
Unique: false,
Columns: []*schema.Column{GroupsColumns[25]},
Columns: []*schema.Column{GroupsColumns[29]},
},
},
}
@@ -572,6 +582,20 @@ var (
},
},
}
// SecuritySecretsColumns holds the columns for the "security_secrets" table.
SecuritySecretsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "key", Type: field.TypeString, Unique: true, Size: 100},
{Name: "value", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
}
// SecuritySecretsTable holds the schema information for the "security_secrets" table.
SecuritySecretsTable = &schema.Table{
Name: "security_secrets",
Columns: SecuritySecretsColumns,
PrimaryKey: []*schema.Column{SecuritySecretsColumns[0]},
}
// SettingsColumns holds the columns for the "settings" table.
SettingsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -650,6 +674,7 @@ var (
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
{Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
{Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16},
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "api_key_id", Type: field.TypeInt64},
@@ -666,31 +691,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[28]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[31]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[31]},
Columns: []*schema.Column{UsageLogsColumns[32]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -699,32 +724,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[31]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[28]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[30]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31]},
Columns: []*schema.Column{UsageLogsColumns[32]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[27]},
},
{
Name: "usagelog_model",
@@ -739,12 +764,12 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
},
},
}
@@ -1000,6 +1025,7 @@ var (
PromoCodeUsagesTable,
ProxiesTable,
RedeemCodesTable,
SecuritySecretsTable,
SettingsTable,
UsageCleanupTasksTable,
UsageLogsTable,
@@ -1056,6 +1082,9 @@ func init() {
RedeemCodesTable.Annotation = &entsql.Annotation{
Table: "redeem_codes",
}
SecuritySecretsTable.Annotation = &entsql.Annotation{
Table: "security_secrets",
}
SettingsTable.Annotation = &entsql.Annotation{
Table: "settings",
}

File diff suppressed because it is too large Load Diff

View File

@@ -39,6 +39,9 @@ type Proxy func(*sql.Selector)
// RedeemCode is the predicate function for redeemcode builders.
type RedeemCode func(*sql.Selector)
// SecuritySecret is the predicate function for securitysecret builders.
type SecuritySecret func(*sql.Selector)
// Setting is the predicate function for setting builders.
type Setting func(*sql.Selector)

View File

@@ -17,6 +17,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/schema"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -93,11 +94,11 @@ func init() {
// apikey.StatusValidator is a validator for the "status" field. It is called by the builders before save.
apikey.StatusValidator = apikeyDescStatus.Validators[0].(func(string) error)
// apikeyDescQuota is the schema descriptor for quota field.
apikeyDescQuota := apikeyFields[7].Descriptor()
apikeyDescQuota := apikeyFields[8].Descriptor()
// apikey.DefaultQuota holds the default value on creation for the quota field.
apikey.DefaultQuota = apikeyDescQuota.Default.(float64)
// apikeyDescQuotaUsed is the schema descriptor for quota_used field.
apikeyDescQuotaUsed := apikeyFields[8].Descriptor()
apikeyDescQuotaUsed := apikeyFields[9].Descriptor()
// apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field.
apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64)
accountMixin := schema.Account{}.Mixin()
@@ -398,23 +399,23 @@ func init() {
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
groupDescClaudeCodeOnly := groupFields[18].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
groupDescModelRoutingEnabled := groupFields[18].Descriptor()
groupDescModelRoutingEnabled := groupFields[22].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
groupDescMcpXMLInject := groupFields[19].Descriptor()
groupDescMcpXMLInject := groupFields[23].Descriptor()
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
groupDescSupportedModelScopes := groupFields[20].Descriptor()
groupDescSupportedModelScopes := groupFields[24].Descriptor()
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
// groupDescSortOrder is the schema descriptor for sort_order field.
groupDescSortOrder := groupFields[21].Descriptor()
groupDescSortOrder := groupFields[25].Descriptor()
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
promocodeFields := schema.PromoCode{}.Fields()
@@ -602,6 +603,43 @@ func init() {
redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor()
// redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field.
redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int)
securitysecretMixin := schema.SecuritySecret{}.Mixin()
securitysecretMixinFields0 := securitysecretMixin[0].Fields()
_ = securitysecretMixinFields0
securitysecretFields := schema.SecuritySecret{}.Fields()
_ = securitysecretFields
// securitysecretDescCreatedAt is the schema descriptor for created_at field.
securitysecretDescCreatedAt := securitysecretMixinFields0[0].Descriptor()
// securitysecret.DefaultCreatedAt holds the default value on creation for the created_at field.
securitysecret.DefaultCreatedAt = securitysecretDescCreatedAt.Default.(func() time.Time)
// securitysecretDescUpdatedAt is the schema descriptor for updated_at field.
securitysecretDescUpdatedAt := securitysecretMixinFields0[1].Descriptor()
// securitysecret.DefaultUpdatedAt holds the default value on creation for the updated_at field.
securitysecret.DefaultUpdatedAt = securitysecretDescUpdatedAt.Default.(func() time.Time)
// securitysecret.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
securitysecret.UpdateDefaultUpdatedAt = securitysecretDescUpdatedAt.UpdateDefault.(func() time.Time)
// securitysecretDescKey is the schema descriptor for key field.
securitysecretDescKey := securitysecretFields[0].Descriptor()
// securitysecret.KeyValidator is a validator for the "key" field. It is called by the builders before save.
securitysecret.KeyValidator = func() func(string) error {
validators := securitysecretDescKey.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
}
return func(key string) error {
for _, fn := range fns {
if err := fn(key); err != nil {
return err
}
}
return nil
}
}()
// securitysecretDescValue is the schema descriptor for value field.
securitysecretDescValue := securitysecretFields[1].Descriptor()
// securitysecret.ValueValidator is a validator for the "value" field. It is called by the builders before save.
securitysecret.ValueValidator = securitysecretDescValue.Validators[0].(func(string) error)
settingFields := schema.Setting{}.Fields()
_ = settingFields
// settingDescKey is the schema descriptor for key field.
@@ -779,12 +817,16 @@ func init() {
usagelogDescImageSize := usagelogFields[28].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescMediaType is the schema descriptor for media_type field.
usagelogDescMediaType := usagelogFields[29].Descriptor()
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[29].Descriptor()
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[30].Descriptor()
usagelogDescCreatedAt := usagelogFields[31].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()

View File

@@ -47,6 +47,10 @@ func (APIKey) Fields() []ent.Field {
field.String("status").
MaxLen(20).
Default(domain.StatusActive),
field.Time("last_used_at").
Optional().
Nillable().
Comment("Last usage time of this API key"),
field.JSON("ip_whitelist", []string{}).
Optional().
Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"),
@@ -95,6 +99,7 @@ func (APIKey) Indexes() []ent.Index {
index.Fields("group_id"),
index.Fields("status"),
index.Fields("deleted_at"),
index.Fields("last_used_at"),
// Index for quota queries
index.Fields("quota", "quota_used"),
index.Fields("expires_at"),

View File

@@ -87,6 +87,24 @@ func (Group) Fields() []ent.Field {
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
// Sora 按次计费配置(阶段 1
field.Float("sora_image_price_360").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
field.Float("sora_image_price_540").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
field.Float("sora_video_price_per_request").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
field.Float("sora_video_price_per_request_hd").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
// Claude Code 客户端限制 (added by migration 029)
field.Bool("claude_code_only").
Default(false).

View File

@@ -0,0 +1,50 @@
package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
)
// IdempotencyRecord 幂等请求记录表。
type IdempotencyRecord struct {
ent.Schema
}
func (IdempotencyRecord) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "idempotency_records"},
}
}
func (IdempotencyRecord) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
func (IdempotencyRecord) Fields() []ent.Field {
return []ent.Field{
field.String("scope").MaxLen(128),
field.String("idempotency_key_hash").MaxLen(64),
field.String("request_fingerprint").MaxLen(64),
field.String("status").MaxLen(32),
field.Int("response_status").Optional().Nillable(),
field.String("response_body").Optional().Nillable(),
field.String("error_reason").MaxLen(128).Optional().Nillable(),
field.Time("locked_until").Optional().Nillable(),
field.Time("expires_at"),
}
}
func (IdempotencyRecord) Indexes() []ent.Index {
return []ent.Index{
index.Fields("scope", "idempotency_key_hash").Unique(),
index.Fields("expires_at"),
index.Fields("status", "locked_until"),
}
}

View File

@@ -0,0 +1,42 @@
package schema
import (
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/field"
)
// SecuritySecret 存储系统级安全密钥(如 JWT 签名密钥、TOTP 加密密钥)。
type SecuritySecret struct {
ent.Schema
}
func (SecuritySecret) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "security_secrets"},
}
}
func (SecuritySecret) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
}
}
func (SecuritySecret) Fields() []ent.Field {
return []ent.Field{
field.String("key").
MaxLen(100).
NotEmpty().
Unique(),
field.String("value").
NotEmpty().
SchemaType(map[string]string{
dialect.Postgres: "text",
}),
}
}

View File

@@ -118,6 +118,11 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(10).
Optional().
Nillable(),
// 媒体类型字段sora 使用)
field.String("media_type").
MaxLen(16).
Optional().
Nillable(),
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
field.Bool("cache_ttl_overridden").

View File

@@ -0,0 +1,139 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
)
// SecuritySecret is the model entity for the SecuritySecret schema.
type SecuritySecret struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// Key holds the value of the "key" field.
Key string `json:"key,omitempty"`
// Value holds the value of the "value" field.
Value string `json:"value,omitempty"`
selectValues sql.SelectValues
}
// scanValues returns the types for scanning values from sql.Rows.
func (*SecuritySecret) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case securitysecret.FieldID:
values[i] = new(sql.NullInt64)
case securitysecret.FieldKey, securitysecret.FieldValue:
values[i] = new(sql.NullString)
case securitysecret.FieldCreatedAt, securitysecret.FieldUpdatedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the SecuritySecret fields.
func (_m *SecuritySecret) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case securitysecret.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case securitysecret.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case securitysecret.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case securitysecret.FieldKey:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field key", values[i])
} else if value.Valid {
_m.Key = value.String
}
case securitysecret.FieldValue:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field value", values[i])
} else if value.Valid {
_m.Value = value.String
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// GetValue returns the ent.Value that was dynamically selected and assigned to the SecuritySecret.
// This includes values selected through modifiers, order, etc.
func (_m *SecuritySecret) GetValue(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// Update returns a builder for updating this SecuritySecret.
// Note that you need to call SecuritySecret.Unwrap() before calling this method if this SecuritySecret
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *SecuritySecret) Update() *SecuritySecretUpdateOne {
return NewSecuritySecretClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the SecuritySecret entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *SecuritySecret) Unwrap() *SecuritySecret {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: SecuritySecret is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *SecuritySecret) String() string {
var builder strings.Builder
builder.WriteString("SecuritySecret(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("key=")
builder.WriteString(_m.Key)
builder.WriteString(", ")
builder.WriteString("value=")
builder.WriteString(_m.Value)
builder.WriteByte(')')
return builder.String()
}
// SecuritySecrets is a parsable slice of SecuritySecret.
type SecuritySecrets []*SecuritySecret

View File

@@ -0,0 +1,86 @@
// Code generated by ent, DO NOT EDIT.
package securitysecret
import (
"time"
"entgo.io/ent/dialect/sql"
)
const (
// Label holds the string label denoting the securitysecret type in the database.
Label = "security_secret"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldKey holds the string denoting the key field in the database.
FieldKey = "key"
// FieldValue holds the string denoting the value field in the database.
FieldValue = "value"
// Table holds the table name of the securitysecret in the database.
Table = "security_secrets"
)
// Columns holds all SQL columns for securitysecret fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldKey,
FieldValue,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// KeyValidator is a validator for the "key" field. It is called by the builders before save.
KeyValidator func(string) error
// ValueValidator is a validator for the "value" field. It is called by the builders before save.
ValueValidator func(string) error
)
// OrderOption defines the ordering options for the SecuritySecret queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByKey orders the results by the key field.
func ByKey(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldKey, opts...).ToFunc()
}
// ByValue orders the results by the value field.
func ByValue(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldValue, opts...).ToFunc()
}

View File

@@ -0,0 +1,300 @@
// Code generated by ent, DO NOT EDIT.
package securitysecret
import (
"time"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldUpdatedAt, v))
}
// Key applies equality check predicate on the "key" field. It's identical to KeyEQ.
func Key(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldKey, v))
}
// Value applies equality check predicate on the "value" field. It's identical to ValueEQ.
func Value(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldValue, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLTE(FieldUpdatedAt, v))
}
// KeyEQ applies the EQ predicate on the "key" field.
func KeyEQ(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldKey, v))
}
// KeyNEQ applies the NEQ predicate on the "key" field.
func KeyNEQ(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNEQ(FieldKey, v))
}
// KeyIn applies the In predicate on the "key" field.
func KeyIn(vs ...string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldIn(FieldKey, vs...))
}
// KeyNotIn applies the NotIn predicate on the "key" field.
func KeyNotIn(vs ...string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNotIn(FieldKey, vs...))
}
// KeyGT applies the GT predicate on the "key" field.
func KeyGT(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGT(FieldKey, v))
}
// KeyGTE applies the GTE predicate on the "key" field.
func KeyGTE(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGTE(FieldKey, v))
}
// KeyLT applies the LT predicate on the "key" field.
func KeyLT(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLT(FieldKey, v))
}
// KeyLTE applies the LTE predicate on the "key" field.
func KeyLTE(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLTE(FieldKey, v))
}
// KeyContains applies the Contains predicate on the "key" field.
func KeyContains(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldContains(FieldKey, v))
}
// KeyHasPrefix applies the HasPrefix predicate on the "key" field.
func KeyHasPrefix(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldHasPrefix(FieldKey, v))
}
// KeyHasSuffix applies the HasSuffix predicate on the "key" field.
func KeyHasSuffix(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldHasSuffix(FieldKey, v))
}
// KeyEqualFold applies the EqualFold predicate on the "key" field.
func KeyEqualFold(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEqualFold(FieldKey, v))
}
// KeyContainsFold applies the ContainsFold predicate on the "key" field.
func KeyContainsFold(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldContainsFold(FieldKey, v))
}
// ValueEQ applies the EQ predicate on the "value" field.
func ValueEQ(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEQ(FieldValue, v))
}
// ValueNEQ applies the NEQ predicate on the "value" field.
func ValueNEQ(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNEQ(FieldValue, v))
}
// ValueIn applies the In predicate on the "value" field.
func ValueIn(vs ...string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldIn(FieldValue, vs...))
}
// ValueNotIn applies the NotIn predicate on the "value" field.
func ValueNotIn(vs ...string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldNotIn(FieldValue, vs...))
}
// ValueGT applies the GT predicate on the "value" field.
func ValueGT(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGT(FieldValue, v))
}
// ValueGTE applies the GTE predicate on the "value" field.
func ValueGTE(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldGTE(FieldValue, v))
}
// ValueLT applies the LT predicate on the "value" field.
func ValueLT(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLT(FieldValue, v))
}
// ValueLTE applies the LTE predicate on the "value" field.
func ValueLTE(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldLTE(FieldValue, v))
}
// ValueContains applies the Contains predicate on the "value" field.
func ValueContains(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldContains(FieldValue, v))
}
// ValueHasPrefix applies the HasPrefix predicate on the "value" field.
func ValueHasPrefix(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldHasPrefix(FieldValue, v))
}
// ValueHasSuffix applies the HasSuffix predicate on the "value" field.
func ValueHasSuffix(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldHasSuffix(FieldValue, v))
}
// ValueEqualFold applies the EqualFold predicate on the "value" field.
func ValueEqualFold(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldEqualFold(FieldValue, v))
}
// ValueContainsFold applies the ContainsFold predicate on the "value" field.
func ValueContainsFold(v string) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.FieldContainsFold(FieldValue, v))
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.SecuritySecret) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.SecuritySecret) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.SecuritySecret) predicate.SecuritySecret {
return predicate.SecuritySecret(sql.NotPredicates(p))
}

View File

@@ -0,0 +1,626 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
)
// SecuritySecretCreate is the builder for creating a SecuritySecret entity.
type SecuritySecretCreate struct {
config
mutation *SecuritySecretMutation
hooks []Hook
conflict []sql.ConflictOption
}
// SetCreatedAt sets the "created_at" field.
func (_c *SecuritySecretCreate) SetCreatedAt(v time.Time) *SecuritySecretCreate {
_c.mutation.SetCreatedAt(v)
return _c
}
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
func (_c *SecuritySecretCreate) SetNillableCreatedAt(v *time.Time) *SecuritySecretCreate {
if v != nil {
_c.SetCreatedAt(*v)
}
return _c
}
// SetUpdatedAt sets the "updated_at" field.
func (_c *SecuritySecretCreate) SetUpdatedAt(v time.Time) *SecuritySecretCreate {
_c.mutation.SetUpdatedAt(v)
return _c
}
// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
func (_c *SecuritySecretCreate) SetNillableUpdatedAt(v *time.Time) *SecuritySecretCreate {
if v != nil {
_c.SetUpdatedAt(*v)
}
return _c
}
// SetKey sets the "key" field.
func (_c *SecuritySecretCreate) SetKey(v string) *SecuritySecretCreate {
_c.mutation.SetKey(v)
return _c
}
// SetValue sets the "value" field.
func (_c *SecuritySecretCreate) SetValue(v string) *SecuritySecretCreate {
_c.mutation.SetValue(v)
return _c
}
// Mutation returns the SecuritySecretMutation object of the builder.
func (_c *SecuritySecretCreate) Mutation() *SecuritySecretMutation {
return _c.mutation
}
// Save creates the SecuritySecret in the database.
func (_c *SecuritySecretCreate) Save(ctx context.Context) (*SecuritySecret, error) {
_c.defaults()
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
// SaveX calls Save and panics if Save returns an error.
func (_c *SecuritySecretCreate) SaveX(ctx context.Context) *SecuritySecret {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *SecuritySecretCreate) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *SecuritySecretCreate) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_c *SecuritySecretCreate) defaults() {
if _, ok := _c.mutation.CreatedAt(); !ok {
v := securitysecret.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
v := securitysecret.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_c *SecuritySecretCreate) check() error {
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "SecuritySecret.created_at"`)}
}
if _, ok := _c.mutation.UpdatedAt(); !ok {
return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "SecuritySecret.updated_at"`)}
}
if _, ok := _c.mutation.Key(); !ok {
return &ValidationError{Name: "key", err: errors.New(`ent: missing required field "SecuritySecret.key"`)}
}
if v, ok := _c.mutation.Key(); ok {
if err := securitysecret.KeyValidator(v); err != nil {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)}
}
}
if _, ok := _c.mutation.Value(); !ok {
return &ValidationError{Name: "value", err: errors.New(`ent: missing required field "SecuritySecret.value"`)}
}
if v, ok := _c.mutation.Value(); ok {
if err := securitysecret.ValueValidator(v); err != nil {
return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)}
}
}
return nil
}
func (_c *SecuritySecretCreate) sqlSave(ctx context.Context) (*SecuritySecret, error) {
if err := _c.check(); err != nil {
return nil, err
}
_node, _spec := _c.createSpec()
if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
id := _spec.ID.Value.(int64)
_node.ID = int64(id)
_c.mutation.id = &_node.ID
_c.mutation.done = true
return _node, nil
}
func (_c *SecuritySecretCreate) createSpec() (*SecuritySecret, *sqlgraph.CreateSpec) {
var (
_node = &SecuritySecret{config: _c.config}
_spec = sqlgraph.NewCreateSpec(securitysecret.Table, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64))
)
_spec.OnConflict = _c.conflict
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(securitysecret.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
}
if value, ok := _c.mutation.UpdatedAt(); ok {
_spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = value
}
if value, ok := _c.mutation.Key(); ok {
_spec.SetField(securitysecret.FieldKey, field.TypeString, value)
_node.Key = value
}
if value, ok := _c.mutation.Value(); ok {
_spec.SetField(securitysecret.FieldValue, field.TypeString, value)
_node.Value = value
}
return _node, _spec
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.SecuritySecret.Create().
// SetCreatedAt(v).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.SecuritySecretUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *SecuritySecretCreate) OnConflict(opts ...sql.ConflictOption) *SecuritySecretUpsertOne {
_c.conflict = opts
return &SecuritySecretUpsertOne{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.SecuritySecret.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *SecuritySecretCreate) OnConflictColumns(columns ...string) *SecuritySecretUpsertOne {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &SecuritySecretUpsertOne{
create: _c,
}
}
type (
// SecuritySecretUpsertOne is the builder for "upsert"-ing
// one SecuritySecret node.
SecuritySecretUpsertOne struct {
create *SecuritySecretCreate
}
// SecuritySecretUpsert is the "OnConflict" setter.
SecuritySecretUpsert struct {
*sql.UpdateSet
}
)
// SetUpdatedAt sets the "updated_at" field.
func (u *SecuritySecretUpsert) SetUpdatedAt(v time.Time) *SecuritySecretUpsert {
u.Set(securitysecret.FieldUpdatedAt, v)
return u
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *SecuritySecretUpsert) UpdateUpdatedAt() *SecuritySecretUpsert {
u.SetExcluded(securitysecret.FieldUpdatedAt)
return u
}
// SetKey sets the "key" field.
func (u *SecuritySecretUpsert) SetKey(v string) *SecuritySecretUpsert {
u.Set(securitysecret.FieldKey, v)
return u
}
// UpdateKey sets the "key" field to the value that was provided on create.
func (u *SecuritySecretUpsert) UpdateKey() *SecuritySecretUpsert {
u.SetExcluded(securitysecret.FieldKey)
return u
}
// SetValue sets the "value" field.
func (u *SecuritySecretUpsert) SetValue(v string) *SecuritySecretUpsert {
u.Set(securitysecret.FieldValue, v)
return u
}
// UpdateValue sets the "value" field to the value that was provided on create.
func (u *SecuritySecretUpsert) UpdateValue() *SecuritySecretUpsert {
u.SetExcluded(securitysecret.FieldValue)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
// client.SecuritySecret.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *SecuritySecretUpsertOne) UpdateNewValues() *SecuritySecretUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
if _, exists := u.create.mutation.CreatedAt(); exists {
s.SetIgnore(securitysecret.FieldCreatedAt)
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.SecuritySecret.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *SecuritySecretUpsertOne) Ignore() *SecuritySecretUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *SecuritySecretUpsertOne) DoNothing() *SecuritySecretUpsertOne {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the SecuritySecretCreate.OnConflict
// documentation for more info.
func (u *SecuritySecretUpsertOne) Update(set func(*SecuritySecretUpsert)) *SecuritySecretUpsertOne {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&SecuritySecretUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *SecuritySecretUpsertOne) SetUpdatedAt(v time.Time) *SecuritySecretUpsertOne {
return u.Update(func(s *SecuritySecretUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *SecuritySecretUpsertOne) UpdateUpdatedAt() *SecuritySecretUpsertOne {
return u.Update(func(s *SecuritySecretUpsert) {
s.UpdateUpdatedAt()
})
}
// SetKey sets the "key" field.
func (u *SecuritySecretUpsertOne) SetKey(v string) *SecuritySecretUpsertOne {
return u.Update(func(s *SecuritySecretUpsert) {
s.SetKey(v)
})
}
// UpdateKey sets the "key" field to the value that was provided on create.
func (u *SecuritySecretUpsertOne) UpdateKey() *SecuritySecretUpsertOne {
return u.Update(func(s *SecuritySecretUpsert) {
s.UpdateKey()
})
}
// SetValue sets the "value" field.
func (u *SecuritySecretUpsertOne) SetValue(v string) *SecuritySecretUpsertOne {
return u.Update(func(s *SecuritySecretUpsert) {
s.SetValue(v)
})
}
// UpdateValue sets the "value" field to the value that was provided on create.
func (u *SecuritySecretUpsertOne) UpdateValue() *SecuritySecretUpsertOne {
return u.Update(func(s *SecuritySecretUpsert) {
s.UpdateValue()
})
}
// Exec executes the query.
func (u *SecuritySecretUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for SecuritySecretCreate.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *SecuritySecretUpsertOne) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}
// Exec executes the UPSERT query and returns the inserted/updated ID.
func (u *SecuritySecretUpsertOne) ID(ctx context.Context) (id int64, err error) {
node, err := u.create.Save(ctx)
if err != nil {
return id, err
}
return node.ID, nil
}
// IDX is like ID, but panics if an error occurs.
func (u *SecuritySecretUpsertOne) IDX(ctx context.Context) int64 {
id, err := u.ID(ctx)
if err != nil {
panic(err)
}
return id
}
// SecuritySecretCreateBulk is the builder for creating many SecuritySecret entities in bulk.
type SecuritySecretCreateBulk struct {
config
err error
builders []*SecuritySecretCreate
conflict []sql.ConflictOption
}
// Save creates the SecuritySecret entities in the database.
func (_c *SecuritySecretCreateBulk) Save(ctx context.Context) ([]*SecuritySecret, error) {
if _c.err != nil {
return nil, _c.err
}
specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
nodes := make([]*SecuritySecret, len(_c.builders))
mutators := make([]Mutator, len(_c.builders))
for i := range _c.builders {
func(i int, root context.Context) {
builder := _c.builders[i]
builder.defaults()
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*SecuritySecretMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err := builder.check(); err != nil {
return nil, err
}
builder.mutation = mutation
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
} else {
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
spec.OnConflict = _c.conflict
// Invoke the actual operation on the latest mutation in the chain.
if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
}
}
if err != nil {
return nil, err
}
mutation.id = &nodes[i].ID
if specs[i].ID.Value != nil {
id := specs[i].ID.Value.(int64)
nodes[i].ID = int64(id)
}
mutation.done = true
return nodes[i], nil
})
for i := len(builder.hooks) - 1; i >= 0; i-- {
mut = builder.hooks[i](mut)
}
mutators[i] = mut
}(i, ctx)
}
if len(mutators) > 0 {
if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
return nil, err
}
}
return nodes, nil
}
// SaveX is like Save, but panics if an error occurs.
func (_c *SecuritySecretCreateBulk) SaveX(ctx context.Context) []*SecuritySecret {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *SecuritySecretCreateBulk) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *SecuritySecretCreateBulk) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
// client.SecuritySecret.CreateBulk(builders...).
// OnConflict(
// // Update the row with the new values
// // the was proposed for insertion.
// sql.ResolveWithNewValues(),
// ).
// // Override some of the fields with custom
// // update values.
// Update(func(u *ent.SecuritySecretUpsert) {
// SetCreatedAt(v+v).
// }).
// Exec(ctx)
func (_c *SecuritySecretCreateBulk) OnConflict(opts ...sql.ConflictOption) *SecuritySecretUpsertBulk {
_c.conflict = opts
return &SecuritySecretUpsertBulk{
create: _c,
}
}
// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
// client.SecuritySecret.Create().
// OnConflict(sql.ConflictColumns(columns...)).
// Exec(ctx)
func (_c *SecuritySecretCreateBulk) OnConflictColumns(columns ...string) *SecuritySecretUpsertBulk {
_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
return &SecuritySecretUpsertBulk{
create: _c,
}
}
// SecuritySecretUpsertBulk is the builder for "upsert"-ing
// a bulk of SecuritySecret nodes.
type SecuritySecretUpsertBulk struct {
create *SecuritySecretCreateBulk
}
// UpdateNewValues updates the mutable fields using the new values that
// were set on create. Using this option is equivalent to using:
//
// client.SecuritySecret.Create().
// OnConflict(
// sql.ResolveWithNewValues(),
// ).
// Exec(ctx)
func (u *SecuritySecretUpsertBulk) UpdateNewValues() *SecuritySecretUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
for _, b := range u.create.builders {
if _, exists := b.mutation.CreatedAt(); exists {
s.SetIgnore(securitysecret.FieldCreatedAt)
}
}
}))
return u
}
// Ignore sets each column to itself in case of conflict.
// Using this option is equivalent to using:
//
// client.SecuritySecret.Create().
// OnConflict(sql.ResolveWithIgnore()).
// Exec(ctx)
func (u *SecuritySecretUpsertBulk) Ignore() *SecuritySecretUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
return u
}
// DoNothing configures the conflict_action to `DO NOTHING`.
// Supported only by SQLite and PostgreSQL.
func (u *SecuritySecretUpsertBulk) DoNothing() *SecuritySecretUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.DoNothing())
return u
}
// Update allows overriding fields `UPDATE` values. See the SecuritySecretCreateBulk.OnConflict
// documentation for more info.
func (u *SecuritySecretUpsertBulk) Update(set func(*SecuritySecretUpsert)) *SecuritySecretUpsertBulk {
u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
set(&SecuritySecretUpsert{UpdateSet: update})
}))
return u
}
// SetUpdatedAt sets the "updated_at" field.
func (u *SecuritySecretUpsertBulk) SetUpdatedAt(v time.Time) *SecuritySecretUpsertBulk {
return u.Update(func(s *SecuritySecretUpsert) {
s.SetUpdatedAt(v)
})
}
// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *SecuritySecretUpsertBulk) UpdateUpdatedAt() *SecuritySecretUpsertBulk {
return u.Update(func(s *SecuritySecretUpsert) {
s.UpdateUpdatedAt()
})
}
// SetKey sets the "key" field.
func (u *SecuritySecretUpsertBulk) SetKey(v string) *SecuritySecretUpsertBulk {
return u.Update(func(s *SecuritySecretUpsert) {
s.SetKey(v)
})
}
// UpdateKey sets the "key" field to the value that was provided on create.
func (u *SecuritySecretUpsertBulk) UpdateKey() *SecuritySecretUpsertBulk {
return u.Update(func(s *SecuritySecretUpsert) {
s.UpdateKey()
})
}
// SetValue sets the "value" field.
func (u *SecuritySecretUpsertBulk) SetValue(v string) *SecuritySecretUpsertBulk {
return u.Update(func(s *SecuritySecretUpsert) {
s.SetValue(v)
})
}
// UpdateValue sets the "value" field to the value that was provided on create.
func (u *SecuritySecretUpsertBulk) UpdateValue() *SecuritySecretUpsertBulk {
return u.Update(func(s *SecuritySecretUpsert) {
s.UpdateValue()
})
}
// Exec executes the query.
func (u *SecuritySecretUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
return u.create.err
}
for i, b := range u.create.builders {
if len(b.conflict) != 0 {
return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the SecuritySecretCreateBulk instead", i)
}
}
if len(u.create.conflict) == 0 {
return errors.New("ent: missing options for SecuritySecretCreateBulk.OnConflict")
}
return u.create.Exec(ctx)
}
// ExecX is like Exec, but panics if an error occurs.
func (u *SecuritySecretUpsertBulk) ExecX(ctx context.Context) {
if err := u.create.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
)
// SecuritySecretDelete is the builder for deleting a SecuritySecret entity.
type SecuritySecretDelete struct {
config
hooks []Hook
mutation *SecuritySecretMutation
}
// Where appends a list predicates to the SecuritySecretDelete builder.
func (_d *SecuritySecretDelete) Where(ps ...predicate.SecuritySecret) *SecuritySecretDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *SecuritySecretDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *SecuritySecretDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *SecuritySecretDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(securitysecret.Table, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// SecuritySecretDeleteOne is the builder for deleting a single SecuritySecret entity.
type SecuritySecretDeleteOne struct {
_d *SecuritySecretDelete
}
// Where appends a list predicates to the SecuritySecretDelete builder.
func (_d *SecuritySecretDeleteOne) Where(ps ...predicate.SecuritySecret) *SecuritySecretDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *SecuritySecretDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{securitysecret.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *SecuritySecretDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,564 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
)
// SecuritySecretQuery is the builder for querying SecuritySecret entities.
type SecuritySecretQuery struct {
config
ctx *QueryContext
order []securitysecret.OrderOption
inters []Interceptor
predicates []predicate.SecuritySecret
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the SecuritySecretQuery builder.
func (_q *SecuritySecretQuery) Where(ps ...predicate.SecuritySecret) *SecuritySecretQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *SecuritySecretQuery) Limit(limit int) *SecuritySecretQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *SecuritySecretQuery) Offset(offset int) *SecuritySecretQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *SecuritySecretQuery) Unique(unique bool) *SecuritySecretQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *SecuritySecretQuery) Order(o ...securitysecret.OrderOption) *SecuritySecretQuery {
_q.order = append(_q.order, o...)
return _q
}
// First returns the first SecuritySecret entity from the query.
// Returns a *NotFoundError when no SecuritySecret was found.
func (_q *SecuritySecretQuery) First(ctx context.Context) (*SecuritySecret, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{securitysecret.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *SecuritySecretQuery) FirstX(ctx context.Context) *SecuritySecret {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first SecuritySecret ID from the query.
// Returns a *NotFoundError when no SecuritySecret ID was found.
func (_q *SecuritySecretQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{securitysecret.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *SecuritySecretQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single SecuritySecret entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one SecuritySecret entity is found.
// Returns a *NotFoundError when no SecuritySecret entities are found.
func (_q *SecuritySecretQuery) Only(ctx context.Context) (*SecuritySecret, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{securitysecret.Label}
default:
return nil, &NotSingularError{securitysecret.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *SecuritySecretQuery) OnlyX(ctx context.Context) *SecuritySecret {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only SecuritySecret ID in the query.
// Returns a *NotSingularError when more than one SecuritySecret ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *SecuritySecretQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{securitysecret.Label}
default:
err = &NotSingularError{securitysecret.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *SecuritySecretQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of SecuritySecrets.
func (_q *SecuritySecretQuery) All(ctx context.Context) ([]*SecuritySecret, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*SecuritySecret, *SecuritySecretQuery]()
return withInterceptors[[]*SecuritySecret](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *SecuritySecretQuery) AllX(ctx context.Context) []*SecuritySecret {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of SecuritySecret IDs.
func (_q *SecuritySecretQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(securitysecret.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *SecuritySecretQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *SecuritySecretQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*SecuritySecretQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *SecuritySecretQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *SecuritySecretQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *SecuritySecretQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the SecuritySecretQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *SecuritySecretQuery) Clone() *SecuritySecretQuery {
if _q == nil {
return nil
}
return &SecuritySecretQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]securitysecret.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.SecuritySecret{}, _q.predicates...),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.SecuritySecret.Query().
// GroupBy(securitysecret.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *SecuritySecretQuery) GroupBy(field string, fields ...string) *SecuritySecretGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &SecuritySecretGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = securitysecret.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.SecuritySecret.Query().
// Select(securitysecret.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *SecuritySecretQuery) Select(fields ...string) *SecuritySecretSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &SecuritySecretSelect{SecuritySecretQuery: _q}
sbuild.label = securitysecret.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a SecuritySecretSelect configured with the given aggregations.
func (_q *SecuritySecretQuery) Aggregate(fns ...AggregateFunc) *SecuritySecretSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *SecuritySecretQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !securitysecret.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *SecuritySecretQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*SecuritySecret, error) {
var (
nodes = []*SecuritySecret{}
_spec = _q.querySpec()
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*SecuritySecret).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &SecuritySecret{config: _q.config}
nodes = append(nodes, node)
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
return nodes, nil
}
func (_q *SecuritySecretQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *SecuritySecretQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, securitysecret.FieldID)
for i := range fields {
if fields[i] != securitysecret.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *SecuritySecretQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(securitysecret.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = securitysecret.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *SecuritySecretQuery) ForUpdate(opts ...sql.LockOption) *SecuritySecretQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *SecuritySecretQuery) ForShare(opts ...sql.LockOption) *SecuritySecretQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// SecuritySecretGroupBy is the group-by builder for SecuritySecret entities.
type SecuritySecretGroupBy struct {
selector
build *SecuritySecretQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *SecuritySecretGroupBy) Aggregate(fns ...AggregateFunc) *SecuritySecretGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *SecuritySecretGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*SecuritySecretQuery, *SecuritySecretGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *SecuritySecretGroupBy) sqlScan(ctx context.Context, root *SecuritySecretQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// SecuritySecretSelect is the builder for selecting fields of SecuritySecret entities.
type SecuritySecretSelect struct {
*SecuritySecretQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *SecuritySecretSelect) Aggregate(fns ...AggregateFunc) *SecuritySecretSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *SecuritySecretSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*SecuritySecretQuery, *SecuritySecretSelect](ctx, _s.SecuritySecretQuery, _s, _s.inters, v)
}
func (_s *SecuritySecretSelect) sqlScan(ctx context.Context, root *SecuritySecretQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@@ -0,0 +1,316 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
)
// SecuritySecretUpdate is the builder for updating SecuritySecret entities.
type SecuritySecretUpdate struct {
config
hooks []Hook
mutation *SecuritySecretMutation
}
// Where appends a list predicates to the SecuritySecretUpdate builder.
func (_u *SecuritySecretUpdate) Where(ps ...predicate.SecuritySecret) *SecuritySecretUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *SecuritySecretUpdate) SetUpdatedAt(v time.Time) *SecuritySecretUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetKey sets the "key" field.
func (_u *SecuritySecretUpdate) SetKey(v string) *SecuritySecretUpdate {
_u.mutation.SetKey(v)
return _u
}
// SetNillableKey sets the "key" field if the given value is not nil.
func (_u *SecuritySecretUpdate) SetNillableKey(v *string) *SecuritySecretUpdate {
if v != nil {
_u.SetKey(*v)
}
return _u
}
// SetValue sets the "value" field.
func (_u *SecuritySecretUpdate) SetValue(v string) *SecuritySecretUpdate {
_u.mutation.SetValue(v)
return _u
}
// SetNillableValue sets the "value" field if the given value is not nil.
func (_u *SecuritySecretUpdate) SetNillableValue(v *string) *SecuritySecretUpdate {
if v != nil {
_u.SetValue(*v)
}
return _u
}
// Mutation returns the SecuritySecretMutation object of the builder.
func (_u *SecuritySecretUpdate) Mutation() *SecuritySecretMutation {
return _u.mutation
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *SecuritySecretUpdate) Save(ctx context.Context) (int, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *SecuritySecretUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *SecuritySecretUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *SecuritySecretUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *SecuritySecretUpdate) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := securitysecret.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *SecuritySecretUpdate) check() error {
if v, ok := _u.mutation.Key(); ok {
if err := securitysecret.KeyValidator(v); err != nil {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)}
}
}
if v, ok := _u.mutation.Value(); ok {
if err := securitysecret.ValueValidator(v); err != nil {
return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)}
}
}
return nil
}
func (_u *SecuritySecretUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Key(); ok {
_spec.SetField(securitysecret.FieldKey, field.TypeString, value)
}
if value, ok := _u.mutation.Value(); ok {
_spec.SetField(securitysecret.FieldValue, field.TypeString, value)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{securitysecret.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// SecuritySecretUpdateOne is the builder for updating a single SecuritySecret entity.
type SecuritySecretUpdateOne struct {
config
fields []string
hooks []Hook
mutation *SecuritySecretMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *SecuritySecretUpdateOne) SetUpdatedAt(v time.Time) *SecuritySecretUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetKey sets the "key" field.
func (_u *SecuritySecretUpdateOne) SetKey(v string) *SecuritySecretUpdateOne {
_u.mutation.SetKey(v)
return _u
}
// SetNillableKey sets the "key" field if the given value is not nil.
func (_u *SecuritySecretUpdateOne) SetNillableKey(v *string) *SecuritySecretUpdateOne {
if v != nil {
_u.SetKey(*v)
}
return _u
}
// SetValue sets the "value" field.
func (_u *SecuritySecretUpdateOne) SetValue(v string) *SecuritySecretUpdateOne {
_u.mutation.SetValue(v)
return _u
}
// SetNillableValue sets the "value" field if the given value is not nil.
func (_u *SecuritySecretUpdateOne) SetNillableValue(v *string) *SecuritySecretUpdateOne {
if v != nil {
_u.SetValue(*v)
}
return _u
}
// Mutation returns the SecuritySecretMutation object of the builder.
func (_u *SecuritySecretUpdateOne) Mutation() *SecuritySecretMutation {
return _u.mutation
}
// Where appends a list predicates to the SecuritySecretUpdate builder.
func (_u *SecuritySecretUpdateOne) Where(ps ...predicate.SecuritySecret) *SecuritySecretUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *SecuritySecretUpdateOne) Select(field string, fields ...string) *SecuritySecretUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated SecuritySecret entity.
func (_u *SecuritySecretUpdateOne) Save(ctx context.Context) (*SecuritySecret, error) {
_u.defaults()
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *SecuritySecretUpdateOne) SaveX(ctx context.Context) *SecuritySecret {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *SecuritySecretUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *SecuritySecretUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *SecuritySecretUpdateOne) defaults() {
if _, ok := _u.mutation.UpdatedAt(); !ok {
v := securitysecret.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_u *SecuritySecretUpdateOne) check() error {
if v, ok := _u.mutation.Key(); ok {
if err := securitysecret.KeyValidator(v); err != nil {
return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)}
}
}
if v, ok := _u.mutation.Value(); ok {
if err := securitysecret.ValueValidator(v); err != nil {
return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)}
}
}
return nil
}
func (_u *SecuritySecretUpdateOne) sqlSave(ctx context.Context) (_node *SecuritySecret, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "SecuritySecret.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, securitysecret.FieldID)
for _, f := range fields {
if !securitysecret.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != securitysecret.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.Key(); ok {
_spec.SetField(securitysecret.FieldKey, field.TypeString, value)
}
if value, ok := _u.mutation.Value(); ok {
_spec.SetField(securitysecret.FieldValue, field.TypeString, value)
}
_node = &SecuritySecret{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{securitysecret.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@@ -36,6 +36,8 @@ type Tx struct {
Proxy *ProxyClient
// RedeemCode is the client for interacting with the RedeemCode builders.
RedeemCode *RedeemCodeClient
// SecuritySecret is the client for interacting with the SecuritySecret builders.
SecuritySecret *SecuritySecretClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
@@ -194,6 +196,7 @@ func (tx *Tx) init() {
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.SecuritySecret = NewSecuritySecretClient(tx.config)
tx.Setting = NewSettingClient(tx.config)
tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config)
tx.UsageLog = NewUsageLogClient(tx.config)

View File

@@ -80,6 +80,8 @@ type UsageLog struct {
ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field.
ImageSize *string `json:"image_size,omitempty"`
// MediaType holds the value of the "media_type" field.
MediaType *string `json:"media_type,omitempty"`
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
// CreatedAt holds the value of the "created_at" field.
@@ -173,7 +175,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
@@ -380,6 +382,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.ImageSize = new(string)
*_m.ImageSize = value.String
}
case usagelog.FieldMediaType:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field media_type", values[i])
} else if value.Valid {
_m.MediaType = new(string)
*_m.MediaType = value.String
}
case usagelog.FieldCacheTTLOverridden:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
@@ -556,6 +565,11 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.MediaType; v != nil {
builder.WriteString("media_type=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("cache_ttl_overridden=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
builder.WriteString(", ")

View File

@@ -72,6 +72,8 @@ const (
FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database.
FieldImageSize = "image_size"
// FieldMediaType holds the string denoting the media_type field in the database.
FieldMediaType = "media_type"
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
FieldCacheTTLOverridden = "cache_ttl_overridden"
// FieldCreatedAt holds the string denoting the created_at field in the database.
@@ -157,6 +159,7 @@ var Columns = []string{
FieldIPAddress,
FieldImageCount,
FieldImageSize,
FieldMediaType,
FieldCacheTTLOverridden,
FieldCreatedAt,
}
@@ -214,6 +217,8 @@ var (
DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
ImageSizeValidator func(string) error
// MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
MediaTypeValidator func(string) error
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
DefaultCacheTTLOverridden bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
@@ -373,6 +378,11 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
}
// ByMediaType orders the results by the media_type field.
func ByMediaType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMediaType, opts...).ToFunc()
}
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()

View File

@@ -200,6 +200,11 @@ func ImageSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
}
// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ.
func MediaType(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
}
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
func CacheTTLOverridden(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
@@ -1445,6 +1450,81 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
}
// MediaTypeEQ applies the EQ predicate on the "media_type" field.
func MediaTypeEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
}
// MediaTypeNEQ applies the NEQ predicate on the "media_type" field.
func MediaTypeNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v))
}
// MediaTypeIn applies the In predicate on the "media_type" field.
func MediaTypeIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...))
}
// MediaTypeNotIn applies the NotIn predicate on the "media_type" field.
func MediaTypeNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...))
}
// MediaTypeGT applies the GT predicate on the "media_type" field.
func MediaTypeGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldMediaType, v))
}
// MediaTypeGTE applies the GTE predicate on the "media_type" field.
func MediaTypeGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v))
}
// MediaTypeLT applies the LT predicate on the "media_type" field.
func MediaTypeLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldMediaType, v))
}
// MediaTypeLTE applies the LTE predicate on the "media_type" field.
func MediaTypeLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v))
}
// MediaTypeContains applies the Contains predicate on the "media_type" field.
func MediaTypeContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldMediaType, v))
}
// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field.
func MediaTypeHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v))
}
// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field.
func MediaTypeHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v))
}
// MediaTypeIsNil applies the IsNil predicate on the "media_type" field.
func MediaTypeIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldMediaType))
}
// MediaTypeNotNil applies the NotNil predicate on the "media_type" field.
func MediaTypeNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldMediaType))
}
// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field.
func MediaTypeEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v))
}
// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field.
func MediaTypeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
}
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))

View File

@@ -393,6 +393,20 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
return _c
}
// SetMediaType sets the "media_type" field.
func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate {
_c.mutation.SetMediaType(v)
return _c
}
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate {
if v != nil {
_c.SetMediaType(*v)
}
return _c
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
_c.mutation.SetCacheTTLOverridden(v)
@@ -645,6 +659,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
if v, ok := _c.mutation.MediaType(); ok {
if err := usagelog.MediaTypeValidator(v); err != nil {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
}
}
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
}
@@ -783,6 +802,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
_node.ImageSize = &value
}
if value, ok := _c.mutation.MediaType(); ok {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
_node.MediaType = &value
}
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
_node.CacheTTLOverridden = value
@@ -1432,6 +1455,24 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
return u
}
// SetMediaType sets the "media_type" field.
func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert {
u.Set(usagelog.FieldMediaType, v)
return u
}
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldMediaType)
return u
}
// ClearMediaType clears the value of the "media_type" field.
func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert {
u.SetNull(usagelog.FieldMediaType)
return u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
u.Set(usagelog.FieldCacheTTLOverridden, v)
@@ -2077,6 +2118,27 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
})
}
// SetMediaType sets the "media_type" field.
func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetMediaType(v)
})
}
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateMediaType()
})
}
// ClearMediaType clears the value of the "media_type" field.
func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearMediaType()
})
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -2890,6 +2952,27 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
})
}
// SetMediaType sets the "media_type" field.
func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetMediaType(v)
})
}
// UpdateMediaType sets the "media_type" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateMediaType()
})
}
// ClearMediaType clears the value of the "media_type" field.
func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearMediaType()
})
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {

View File

@@ -612,6 +612,26 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
return _u
}
// SetMediaType sets the "media_type" field.
func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate {
_u.mutation.SetMediaType(v)
return _u
}
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate {
if v != nil {
_u.SetMediaType(*v)
}
return _u
}
// ClearMediaType clears the value of the "media_type" field.
func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate {
_u.mutation.ClearMediaType()
return _u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
_u.mutation.SetCacheTTLOverridden(v)
@@ -740,6 +760,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
if v, ok := _u.mutation.MediaType(); ok {
if err := usagelog.MediaTypeValidator(v); err != nil {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@@ -908,6 +933,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
if value, ok := _u.mutation.MediaType(); ok {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
}
if _u.mutation.MediaTypeCleared() {
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
}
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
@@ -1656,6 +1687,26 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
return _u
}
// SetMediaType sets the "media_type" field.
func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne {
_u.mutation.SetMediaType(v)
return _u
}
// SetNillableMediaType sets the "media_type" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetMediaType(*v)
}
return _u
}
// ClearMediaType clears the value of the "media_type" field.
func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne {
_u.mutation.ClearMediaType()
return _u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
_u.mutation.SetCacheTTLOverridden(v)
@@ -1797,6 +1848,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
if v, ok := _u.mutation.MediaType(); ok {
if err := usagelog.MediaTypeValidator(v); err != nil {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@@ -1982,6 +2038,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
if value, ok := _u.mutation.MediaType(); ok {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
}
if _u.mutation.MediaTypeCleared() {
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
}
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}

View File

@@ -5,6 +5,8 @@ go 1.25.7
require (
entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/alitto/pond/v2 v2.6.2
github.com/cespare/xxhash/v2 v2.3.0
github.com/dgraph-io/ristretto v0.2.0
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.2
@@ -13,9 +15,10 @@ require (
github.com/gorilla/websocket v1.5.3
github.com/imroc/req/v3 v3.57.0
github.com/lib/pq v1.10.9
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pquerna/otp v1.5.0
github.com/redis/go-redis/v9 v9.17.2
github.com/refraction-networking/utls v1.8.1
github.com/refraction-networking/utls v1.8.2
github.com/robfig/cron/v3 v3.0.1
github.com/shirou/gopsutil/v4 v4.25.6
github.com/spf13/viper v1.18.2
@@ -25,10 +28,12 @@ require (
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/zeromicro/go-zero v1.9.4
go.uber.org/zap v1.24.0
golang.org/x/crypto v0.47.0
golang.org/x/net v0.49.0
golang.org/x/sync v0.19.0
golang.org/x/term v0.39.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3
)
@@ -45,7 +50,6 @@ require (
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
@@ -103,7 +107,6 @@ require (
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect

View File

@@ -14,10 +14,14 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw=
github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
@@ -135,8 +139,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -172,8 +174,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -207,8 +207,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -234,12 +232,10 @@ github.com/quic-go/quic-go v0.57.1 h1:25KAAR9QR8KZrCZRThWMKVAwGoiHIrNbT72ULHTuI1
github.com/quic-go/quic-go v0.57.1/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@@ -262,8 +258,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
@@ -340,10 +334,14 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60=
go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
@@ -391,6 +389,8 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -5,7 +5,7 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"log"
"log/slog"
"net/url"
"os"
"strings"
@@ -19,6 +19,13 @@ const (
RunModeSimple = "simple"
)
// 使用量记录队列溢出策略
const (
UsageRecordOverflowPolicyDrop = "drop"
UsageRecordOverflowPolicySample = "sample"
UsageRecordOverflowPolicySync = "sync"
)
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
@@ -38,31 +45,68 @@ const (
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
Server ServerConfig `mapstructure:"server"`
Log LogConfig `mapstructure:"log"`
CORS CORSConfig `mapstructure:"cors"`
Security SecurityConfig `mapstructure:"security"`
Billing BillingConfig `mapstructure:"billing"`
Turnstile TurnstileConfig `mapstructure:"turnstile"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"`
SubscriptionMaintenance SubscriptionMaintenanceConfig `mapstructure:"subscription_maintenance"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
Sora SoraConfig `mapstructure:"sora"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
Update UpdateConfig `mapstructure:"update"`
Idempotency IdempotencyConfig `mapstructure:"idempotency"`
}
type LogConfig struct {
Level string `mapstructure:"level"`
Format string `mapstructure:"format"`
ServiceName string `mapstructure:"service_name"`
Environment string `mapstructure:"env"`
Caller bool `mapstructure:"caller"`
StacktraceLevel string `mapstructure:"stacktrace_level"`
Output LogOutputConfig `mapstructure:"output"`
Rotation LogRotationConfig `mapstructure:"rotation"`
Sampling LogSamplingConfig `mapstructure:"sampling"`
}
type LogOutputConfig struct {
ToStdout bool `mapstructure:"to_stdout"`
ToFile bool `mapstructure:"to_file"`
FilePath string `mapstructure:"file_path"`
}
type LogRotationConfig struct {
MaxSizeMB int `mapstructure:"max_size_mb"`
MaxBackups int `mapstructure:"max_backups"`
MaxAgeDays int `mapstructure:"max_age_days"`
Compress bool `mapstructure:"compress"`
LocalTime bool `mapstructure:"local_time"`
}
type LogSamplingConfig struct {
Enabled bool `mapstructure:"enabled"`
Initial int `mapstructure:"initial"`
Thereafter int `mapstructure:"thereafter"`
}
type GeminiConfig struct {
@@ -94,6 +138,25 @@ type UpdateConfig struct {
ProxyURL string `mapstructure:"proxy_url"`
}
type IdempotencyConfig struct {
// ObserveOnly 为 true 时处于观察期:未携带 Idempotency-Key 的请求继续放行。
ObserveOnly bool `mapstructure:"observe_only"`
// DefaultTTLSeconds 关键写接口的幂等记录默认 TTL
DefaultTTLSeconds int `mapstructure:"default_ttl_seconds"`
// SystemOperationTTLSeconds 系统操作接口的幂等记录 TTL
SystemOperationTTLSeconds int `mapstructure:"system_operation_ttl_seconds"`
// ProcessingTimeoutSeconds processing 状态锁超时(秒)。
ProcessingTimeoutSeconds int `mapstructure:"processing_timeout_seconds"`
// FailedRetryBackoffSeconds 失败退避窗口(秒)。
FailedRetryBackoffSeconds int `mapstructure:"failed_retry_backoff_seconds"`
// MaxStoredResponseLen 持久化响应体最大长度(字节)。
MaxStoredResponseLen int `mapstructure:"max_stored_response_len"`
// CleanupIntervalSeconds 过期记录清理周期(秒)。
CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"`
// CleanupBatchSize 每次清理的最大记录数。
CleanupBatchSize int `mapstructure:"cleanup_batch_size"`
}
type LinuxDoConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
@@ -126,6 +189,8 @@ type TokenRefreshConfig struct {
MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
// 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token默认关闭
SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
}
type PricingConfig struct {
@@ -147,6 +212,7 @@ type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release
FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL用于生成邮件中的外部链接
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表CIDR/IP
@@ -173,6 +239,7 @@ type SecurityConfig struct {
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
CSP CSPConfig `mapstructure:"csp"`
ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"`
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
}
@@ -197,6 +264,12 @@ type CSPConfig struct {
Policy string `mapstructure:"policy"`
}
type ProxyFallbackConfig struct {
// AllowDirectOnError 当代理初始化失败时是否允许回退直连。
// 默认 false避免因代理配置错误导致 IP 泄露/关联。
AllowDirectOnError bool `mapstructure:"allow_direct_on_error"`
}
type ProxyProbeConfig struct {
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证
}
@@ -217,6 +290,59 @@ type ConcurrencyConfig struct {
PingInterval int `mapstructure:"ping_interval"`
}
// SoraConfig 直连 Sora 配置
type SoraConfig struct {
Client SoraClientConfig `mapstructure:"client"`
Storage SoraStorageConfig `mapstructure:"storage"`
}
// SoraClientConfig 直连 Sora 客户端配置
type SoraClientConfig struct {
BaseURL string `mapstructure:"base_url"`
TimeoutSeconds int `mapstructure:"timeout_seconds"`
MaxRetries int `mapstructure:"max_retries"`
CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
RecentTaskLimit int `mapstructure:"recent_task_limit"`
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
Debug bool `mapstructure:"debug"`
UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
Headers map[string]string `mapstructure:"headers"`
UserAgent string `mapstructure:"user_agent"`
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
}
// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置
type SoraCurlCFFISidecarConfig struct {
Enabled bool `mapstructure:"enabled"`
BaseURL string `mapstructure:"base_url"`
Impersonate string `mapstructure:"impersonate"`
TimeoutSeconds int `mapstructure:"timeout_seconds"`
SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
}
// SoraStorageConfig 媒体存储配置
type SoraStorageConfig struct {
Type string `mapstructure:"type"`
LocalPath string `mapstructure:"local_path"`
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
Debug bool `mapstructure:"debug"`
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
}
// SoraStorageCleanupConfig 媒体清理配置
type SoraStorageCleanupConfig struct {
Enabled bool `mapstructure:"enabled"`
Schedule string `mapstructure:"schedule"`
RetentionDays int `mapstructure:"retention_days"`
}
// GatewayConfig API网关相关配置
type GatewayConfig struct {
// 等待上游响应头的超时时间0表示无超时
@@ -224,8 +350,20 @@ type GatewayConfig struct {
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
// 请求体最大字节数,用于网关请求体大小限制
MaxBodySize int64 `mapstructure:"max_body_size"`
// 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大
UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"`
// 代理探测响应体读取上限(字节)
ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"`
// Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销)
GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"`
// ConnectionPoolIsolation: 上游连接池隔离策略proxy/account/account_proxy
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
@@ -271,6 +409,24 @@ type GatewayConfig struct {
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"`
// Sora 专用配置
// SoraMaxBodySize: Sora 请求体最大字节数0 表示使用 gateway.max_body_size
SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"`
// SoraStreamTimeoutSeconds: Sora 流式请求总超时0 表示不限制)
SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"`
// SoraRequestTimeoutSeconds: Sora 非流式请求超时0 表示不限制)
SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"`
// SoraStreamMode: stream 强制策略force/error
SoraStreamMode string `mapstructure:"sora_stream_mode"`
// SoraModelFilters: 模型列表过滤配置
SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"`
// SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key
SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"`
// SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名)
SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"`
// SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用)
SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"`
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
MaxAccountSwitches int `mapstructure:"max_account_switches"`
// Gemini 账户切换最大次数Gemini 平台单独配置,因 API 限制更严格)
@@ -284,6 +440,53 @@ type GatewayConfig struct {
// TLSFingerprint: TLS指纹伪装配置
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
// UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker
UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"`
// UserGroupRateCacheTTLSeconds: 用户分组倍率热路径缓存 TTL
UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"`
// ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
}
// GatewayUsageRecordConfig 使用量记录异步队列配置
type GatewayUsageRecordConfig struct {
// WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限)
WorkerCount int `mapstructure:"worker_count"`
// QueueSize: 队列容量(有界)
QueueSize int `mapstructure:"queue_size"`
// TaskTimeoutSeconds: 单个使用量记录任务超时(秒)
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
// OverflowPolicy: 队列满时策略drop/sample/sync
OverflowPolicy string `mapstructure:"overflow_policy"`
// OverflowSamplePercent: sample 策略下同步回写采样百分比1-100
OverflowSamplePercent int `mapstructure:"overflow_sample_percent"`
// AutoScaleEnabled: 是否启用 worker 自动扩缩容
AutoScaleEnabled bool `mapstructure:"auto_scale_enabled"`
// AutoScaleMinWorkers: 自动扩缩容最小 worker 数
AutoScaleMinWorkers int `mapstructure:"auto_scale_min_workers"`
// AutoScaleMaxWorkers: 自动扩缩容最大 worker 数
AutoScaleMaxWorkers int `mapstructure:"auto_scale_max_workers"`
// AutoScaleUpQueuePercent: 队列占用率达到该阈值时触发扩容
AutoScaleUpQueuePercent int `mapstructure:"auto_scale_up_queue_percent"`
// AutoScaleDownQueuePercent: 队列占用率低于该阈值时触发缩容
AutoScaleDownQueuePercent int `mapstructure:"auto_scale_down_queue_percent"`
// AutoScaleUpStep: 每次扩容步长
AutoScaleUpStep int `mapstructure:"auto_scale_up_step"`
// AutoScaleDownStep: 每次缩容步长
AutoScaleDownStep int `mapstructure:"auto_scale_down_step"`
// AutoScaleCheckIntervalSeconds: 自动扩缩容检测间隔(秒)
AutoScaleCheckIntervalSeconds int `mapstructure:"auto_scale_check_interval_seconds"`
// AutoScaleCooldownSeconds: 自动扩缩容冷却时间(秒)
AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
}
// SoraModelFiltersConfig Sora 模型过滤配置
type SoraModelFiltersConfig struct {
// HidePromptEnhance 是否隐藏 prompt-enhance 模型
HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"`
}
// TLSFingerprintConfig TLS指纹伪装配置
@@ -479,8 +682,9 @@ type OpsMetricsCollectorCacheConfig struct {
type JWTConfig struct {
Secret string `mapstructure:"secret"`
ExpireHour int `mapstructure:"expire_hour"`
// AccessTokenExpireMinutes: Access Token有效期分钟默认15分钟
// 短有效期减少被盗用风险配合Refresh Token实现无感续期
// AccessTokenExpireMinutes: Access Token有效期分钟
// - >0: 使用分钟配置(优先级高于 ExpireHour
// - =0: 回退使用 ExpireHour向后兼容旧配置
AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"`
// RefreshTokenExpireDays: Refresh Token有效期默认30天
RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"`
@@ -525,6 +729,20 @@ type APIKeyAuthCacheConfig struct {
Singleflight bool `mapstructure:"singleflight"`
}
// SubscriptionCacheConfig 订阅认证 L1 缓存配置
type SubscriptionCacheConfig struct {
L1Size int `mapstructure:"l1_size"`
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
JitterPercent int `mapstructure:"jitter_percent"`
}
// SubscriptionMaintenanceConfig 订阅窗口维护后台任务配置。
// 用于将“请求路径触发的维护动作”有界化,避免高并发下 goroutine 膨胀。
type SubscriptionMaintenanceConfig struct {
WorkerCount int `mapstructure:"worker_count"`
QueueSize int `mapstructure:"queue_size"`
}
// DashboardCacheConfig 仪表盘统计缓存配置
type DashboardCacheConfig struct {
// Enabled: 是否启用仪表盘缓存
@@ -588,7 +806,19 @@ func NormalizeRunMode(value string) string {
}
}
// Load 读取并校验完整配置(要求 jwt.secret 已显式提供)。
func Load() (*Config, error) {
return load(false)
}
// LoadForBootstrap 读取启动阶段配置。
//
// 启动阶段允许 jwt.secret 先留空,后续由数据库初始化流程补齐并再次完整校验。
func LoadForBootstrap() (*Config, error) {
return load(true)
}
func load(allowMissingJWTSecret bool) (*Config, error) {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
@@ -630,6 +860,7 @@ func Load() (*Config, error) {
if cfg.Server.Mode == "" {
cfg.Server.Mode = "debug"
}
cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL)
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
@@ -648,15 +879,12 @@ func Load() (*Config, error) {
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
if cfg.JWT.Secret == "" {
secret, err := generateJWTSecret(64)
if err != nil {
return nil, fmt.Errorf("generate jwt secret error: %w", err)
}
cfg.JWT.Secret = secret
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
}
cfg.Log.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level))
cfg.Log.Format = strings.ToLower(strings.TrimSpace(cfg.Log.Format))
cfg.Log.ServiceName = strings.TrimSpace(cfg.Log.ServiceName)
cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment)
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
@@ -667,29 +895,39 @@ func Load() (*Config, error) {
}
cfg.Totp.EncryptionKey = key
cfg.Totp.EncryptionKeyConfigured = false
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
slog.Warn("TOTP encryption key auto-generated. Consider setting a fixed key for production.")
} else {
cfg.Totp.EncryptionKeyConfigured = true
}
originalJWTSecret := cfg.JWT.Secret
if allowMissingJWTSecret && originalJWTSecret == "" {
// 启动阶段允许先无 JWT 密钥,后续在数据库初始化后补齐。
cfg.JWT.Secret = strings.Repeat("0", 32)
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
if allowMissingJWTSecret && originalJWTSecret == "" {
cfg.JWT.Secret = ""
}
if !cfg.Security.URLAllowlist.Enabled {
log.Println("Warning: security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
slog.Warn("security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
}
if !cfg.Security.ResponseHeaders.Enabled {
log.Println("Warning: security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
slog.Warn("security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
}
if cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
slog.Warn("JWT secret appears weak; use a 32+ character random secret in production.")
}
if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v",
cfg.Security.ResponseHeaders.AdditionalAllowed,
cfg.Security.ResponseHeaders.ForceRemove,
slog.Info("response header policy configured",
"additional_allowed", cfg.Security.ResponseHeaders.AdditionalAllowed,
"force_remove", cfg.Security.ResponseHeaders.ForceRemove,
)
}
@@ -702,7 +940,8 @@ func setDefaults() {
// Server
viper.SetDefault("server.host", "0.0.0.0")
viper.SetDefault("server.port", 8080)
viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.mode", "release")
viper.SetDefault("server.frontend_url", "")
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{})
@@ -715,6 +954,25 @@ func setDefaults() {
viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB
viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB
// Log
viper.SetDefault("log.level", "info")
viper.SetDefault("log.format", "console")
viper.SetDefault("log.service_name", "sub2api")
viper.SetDefault("log.env", "production")
viper.SetDefault("log.caller", true)
viper.SetDefault("log.stacktrace_level", "error")
viper.SetDefault("log.output.to_stdout", true)
viper.SetDefault("log.output.to_file", true)
viper.SetDefault("log.output.file_path", "")
viper.SetDefault("log.rotation.max_size_mb", 100)
viper.SetDefault("log.rotation.max_backups", 10)
viper.SetDefault("log.rotation.max_age_days", 7)
viper.SetDefault("log.rotation.compress", true)
viper.SetDefault("log.rotation.local_time", true)
viper.SetDefault("log.sampling.enabled", false)
viper.SetDefault("log.sampling.initial", 100)
viper.SetDefault("log.sampling.thereafter", 100)
// CORS
viper.SetDefault("cors.allowed_origins", []string{})
viper.SetDefault("cors.allow_credentials", true)
@@ -737,7 +995,7 @@ func setDefaults() {
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
viper.SetDefault("security.url_allowlist.allow_private_hosts", true)
viper.SetDefault("security.url_allowlist.allow_insecure_http", true)
viper.SetDefault("security.response_headers.enabled", false)
viper.SetDefault("security.response_headers.enabled", true)
viper.SetDefault("security.response_headers.additional_allowed", []string{})
viper.SetDefault("security.response_headers.force_remove", []string{})
viper.SetDefault("security.csp.enabled", true)
@@ -775,9 +1033,9 @@ func setDefaults() {
viper.SetDefault("database.user", "postgres")
viper.SetDefault("database.password", "postgres")
viper.SetDefault("database.dbname", "sub2api")
viper.SetDefault("database.sslmode", "disable")
viper.SetDefault("database.max_open_conns", 50)
viper.SetDefault("database.max_idle_conns", 10)
viper.SetDefault("database.sslmode", "prefer")
viper.SetDefault("database.max_open_conns", 256)
viper.SetDefault("database.max_idle_conns", 128)
viper.SetDefault("database.conn_max_lifetime_minutes", 30)
viper.SetDefault("database.conn_max_idle_time_minutes", 5)
@@ -789,8 +1047,8 @@ func setDefaults() {
viper.SetDefault("redis.dial_timeout_seconds", 5)
viper.SetDefault("redis.read_timeout_seconds", 3)
viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128)
viper.SetDefault("redis.min_idle_conns", 10)
viper.SetDefault("redis.pool_size", 1024)
viper.SetDefault("redis.min_idle_conns", 128)
viper.SetDefault("redis.enable_tls", false)
// Ops (vNext)
@@ -810,9 +1068,9 @@ func setDefaults() {
// JWT
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期
viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期
viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新
viper.SetDefault("jwt.access_token_expire_minutes", 0) // 0 表示回退到 expire_hour
viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期
viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新
// TOTP
viper.SetDefault("totp.encryption_key", "")
@@ -849,6 +1107,11 @@ func setDefaults() {
viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
viper.SetDefault("api_key_auth_cache.singleflight", true)
// Subscription auth L1 cache
viper.SetDefault("subscription_cache.l1_size", 16384)
viper.SetDefault("subscription_cache.l1_ttl_seconds", 10)
viper.SetDefault("subscription_cache.jitter_percent", 10)
// Dashboard cache
viper.SetDefault("dashboard_cache.enabled", true)
viper.SetDefault("dashboard_cache.key_prefix", "sub2api:")
@@ -874,6 +1137,16 @@ func setDefaults() {
viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
// Idempotency
viper.SetDefault("idempotency.observe_only", true)
viper.SetDefault("idempotency.default_ttl_seconds", 86400)
viper.SetDefault("idempotency.system_operation_ttl_seconds", 3600)
viper.SetDefault("idempotency.processing_timeout_seconds", 30)
viper.SetDefault("idempotency.failed_retry_backoff_seconds", 5)
viper.SetDefault("idempotency.max_stored_response_len", 64*1024)
viper.SetDefault("idempotency.cleanup_interval_seconds", 60)
viper.SetDefault("idempotency.cleanup_batch_size", 500)
// Gateway
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", true)
@@ -882,14 +1155,26 @@ func setDefaults() {
viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_account_switches", 10)
viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
viper.SetDefault("gateway.sora_stream_mode", "force")
viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true)
viper.SetDefault("gateway.sora_media_require_api_key", true)
viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认
viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃HTTP/2 场景默认
viper.SetDefault("gateway.max_conns_per_host", 1024) // 每主机最大连接数(含活跃;流式/HTTP1.1 场景可调大,如 2400+
viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
@@ -913,16 +1198,65 @@ func setDefaults() {
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
viper.SetDefault("gateway.usage_record.worker_count", 128)
viper.SetDefault("gateway.usage_record.queue_size", 16384)
viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5)
viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample)
viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10)
viper.SetDefault("gateway.usage_record.auto_scale_enabled", true)
viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128)
viper.SetDefault("gateway.usage_record.auto_scale_max_workers", 512)
viper.SetDefault("gateway.usage_record.auto_scale_up_queue_percent", 70)
viper.SetDefault("gateway.usage_record.auto_scale_down_queue_percent", 15)
viper.SetDefault("gateway.usage_record.auto_scale_up_step", 32)
viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16)
viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3)
viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10)
viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30)
viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15)
// TLS指纹伪装配置默认关闭需要账号级别单独启用
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
viper.SetDefault("concurrency.ping_interval", 10)
// Sora 直连配置
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
viper.SetDefault("sora.client.timeout_seconds", 120)
viper.SetDefault("sora.client.max_retries", 3)
viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900)
viper.SetDefault("sora.client.poll_interval_seconds", 2)
viper.SetDefault("sora.client.max_poll_attempts", 600)
viper.SetDefault("sora.client.recent_task_limit", 50)
viper.SetDefault("sora.client.recent_task_limit_max", 200)
viper.SetDefault("sora.client.debug", false)
viper.SetDefault("sora.client.use_openai_token_provider", false)
viper.SetDefault("sora.client.headers", map[string]string{})
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true)
viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080")
viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131")
viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60)
viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true)
viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600)
viper.SetDefault("sora.storage.type", "local")
viper.SetDefault("sora.storage.local_path", "")
viper.SetDefault("sora.storage.fallback_to_upstream", true)
viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
viper.SetDefault("sora.storage.download_timeout_seconds", 120)
viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20))
viper.SetDefault("sora.storage.debug", false)
viper.SetDefault("sora.storage.cleanup.enabled", true)
viper.SetDefault("sora.storage.cleanup.retention_days", 7)
viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *")
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新适配Google 1小时token
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
// Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
@@ -931,9 +1265,106 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
// Security - proxy fallback
viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
// Subscription Maintenance (bounded queue + worker pool)
viper.SetDefault("subscription_maintenance.worker_count", 2)
viper.SetDefault("subscription_maintenance.queue_size", 1024)
}
func (c *Config) Validate() error {
jwtSecret := strings.TrimSpace(c.JWT.Secret)
if jwtSecret == "" {
return fmt.Errorf("jwt.secret is required")
}
// NOTE: 按 UTF-8 编码后的字节长度计算。
// 选择 bytes 而不是 rune 计数,确保二进制/随机串的长度语义更接近“熵”而非“字符数”。
if len([]byte(jwtSecret)) < 32 {
return fmt.Errorf("jwt.secret must be at least 32 bytes")
}
switch c.Log.Level {
case "debug", "info", "warn", "error":
case "":
return fmt.Errorf("log.level is required")
default:
return fmt.Errorf("log.level must be one of: debug/info/warn/error")
}
switch c.Log.Format {
case "json", "console":
case "":
return fmt.Errorf("log.format is required")
default:
return fmt.Errorf("log.format must be one of: json/console")
}
switch c.Log.StacktraceLevel {
case "none", "error", "fatal":
case "":
return fmt.Errorf("log.stacktrace_level is required")
default:
return fmt.Errorf("log.stacktrace_level must be one of: none/error/fatal")
}
if !c.Log.Output.ToStdout && !c.Log.Output.ToFile {
return fmt.Errorf("log.output.to_stdout and log.output.to_file cannot both be false")
}
if c.Log.Rotation.MaxSizeMB <= 0 {
return fmt.Errorf("log.rotation.max_size_mb must be positive")
}
if c.Log.Rotation.MaxBackups < 0 {
return fmt.Errorf("log.rotation.max_backups must be non-negative")
}
if c.Log.Rotation.MaxAgeDays < 0 {
return fmt.Errorf("log.rotation.max_age_days must be non-negative")
}
if c.Log.Sampling.Enabled {
if c.Log.Sampling.Initial <= 0 {
return fmt.Errorf("log.sampling.initial must be positive when sampling is enabled")
}
if c.Log.Sampling.Thereafter <= 0 {
return fmt.Errorf("log.sampling.thereafter must be positive when sampling is enabled")
}
} else {
if c.Log.Sampling.Initial < 0 {
return fmt.Errorf("log.sampling.initial must be non-negative")
}
if c.Log.Sampling.Thereafter < 0 {
return fmt.Errorf("log.sampling.thereafter must be non-negative")
}
}
if c.SubscriptionMaintenance.WorkerCount < 0 {
return fmt.Errorf("subscription_maintenance.worker_count must be non-negative")
}
if c.SubscriptionMaintenance.QueueSize < 0 {
return fmt.Errorf("subscription_maintenance.queue_size must be non-negative")
}
// Gemini OAuth 配置校验client_id 与 client_secret 必须同时设置或同时留空。
// 留空时表示使用内置的 Gemini CLI OAuth 客户端(其 client_secret 通过环境变量注入)。
geminiClientID := strings.TrimSpace(c.Gemini.OAuth.ClientID)
geminiClientSecret := strings.TrimSpace(c.Gemini.OAuth.ClientSecret)
if (geminiClientID == "") != (geminiClientSecret == "") {
return fmt.Errorf("gemini.oauth.client_id and gemini.oauth.client_secret must be both set or both empty")
}
if strings.TrimSpace(c.Server.FrontendURL) != "" {
if err := ValidateAbsoluteHTTPURL(c.Server.FrontendURL); err != nil {
return fmt.Errorf("server.frontend_url invalid: %w", err)
}
u, err := url.Parse(strings.TrimSpace(c.Server.FrontendURL))
if err != nil {
return fmt.Errorf("server.frontend_url invalid: %w", err)
}
if u.RawQuery != "" || u.ForceQuery {
return fmt.Errorf("server.frontend_url invalid: must not include query")
}
if u.User != nil {
return fmt.Errorf("server.frontend_url invalid: must not include userinfo")
}
warnIfInsecureURL("server.frontend_url", c.Server.FrontendURL)
}
if c.JWT.ExpireHour <= 0 {
return fmt.Errorf("jwt.expire_hour must be positive")
}
@@ -941,20 +1372,20 @@ func (c *Config) Validate() error {
return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)")
}
if c.JWT.ExpireHour > 24 {
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
slog.Warn("jwt.expire_hour is high; consider shorter expiration for security", "expire_hour", c.JWT.ExpireHour)
}
// JWT Refresh Token配置验证
if c.JWT.AccessTokenExpireMinutes <= 0 {
return fmt.Errorf("jwt.access_token_expire_minutes must be positive")
if c.JWT.AccessTokenExpireMinutes < 0 {
return fmt.Errorf("jwt.access_token_expire_minutes must be non-negative")
}
if c.JWT.AccessTokenExpireMinutes > 720 {
log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes)
slog.Warn("jwt.access_token_expire_minutes is high; consider shorter expiration for security", "access_token_expire_minutes", c.JWT.AccessTokenExpireMinutes)
}
if c.JWT.RefreshTokenExpireDays <= 0 {
return fmt.Errorf("jwt.refresh_token_expire_days must be positive")
}
if c.JWT.RefreshTokenExpireDays > 90 {
log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays)
slog.Warn("jwt.refresh_token_expire_days is high; consider shorter expiration for security", "refresh_token_expire_days", c.JWT.RefreshTokenExpireDays)
}
if c.JWT.RefreshWindowMinutes < 0 {
return fmt.Errorf("jwt.refresh_window_minutes must be non-negative")
@@ -1160,9 +1591,116 @@ func (c *Config) Validate() error {
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative")
}
}
if c.Idempotency.DefaultTTLSeconds <= 0 {
return fmt.Errorf("idempotency.default_ttl_seconds must be positive")
}
if c.Idempotency.SystemOperationTTLSeconds <= 0 {
return fmt.Errorf("idempotency.system_operation_ttl_seconds must be positive")
}
if c.Idempotency.ProcessingTimeoutSeconds <= 0 {
return fmt.Errorf("idempotency.processing_timeout_seconds must be positive")
}
if c.Idempotency.FailedRetryBackoffSeconds <= 0 {
return fmt.Errorf("idempotency.failed_retry_backoff_seconds must be positive")
}
if c.Idempotency.MaxStoredResponseLen <= 0 {
return fmt.Errorf("idempotency.max_stored_response_len must be positive")
}
if c.Idempotency.CleanupIntervalSeconds <= 0 {
return fmt.Errorf("idempotency.cleanup_interval_seconds must be positive")
}
if c.Idempotency.CleanupBatchSize <= 0 {
return fmt.Errorf("idempotency.cleanup_batch_size must be positive")
}
if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive")
}
if c.Gateway.UpstreamResponseReadMaxBytes <= 0 {
return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive")
}
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
}
if c.Gateway.SoraMaxBodySize < 0 {
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
}
if c.Gateway.SoraStreamTimeoutSeconds < 0 {
return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative")
}
if c.Gateway.SoraRequestTimeoutSeconds < 0 {
return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative")
}
if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 {
return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative")
}
if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" {
switch mode {
case "force", "error":
default:
return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error")
}
}
if c.Sora.Client.TimeoutSeconds < 0 {
return fmt.Errorf("sora.client.timeout_seconds must be non-negative")
}
if c.Sora.Client.MaxRetries < 0 {
return fmt.Errorf("sora.client.max_retries must be non-negative")
}
if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 {
return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative")
}
if c.Sora.Client.PollIntervalSeconds < 0 {
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
}
if c.Sora.Client.MaxPollAttempts < 0 {
return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
}
if c.Sora.Client.RecentTaskLimit < 0 {
return fmt.Errorf("sora.client.recent_task_limit must be non-negative")
}
if c.Sora.Client.RecentTaskLimitMax < 0 {
return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative")
}
if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 &&
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
}
if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 {
return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative")
}
if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 {
return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative")
}
if !c.Sora.Client.CurlCFFISidecar.Enabled {
return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true")
}
if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" {
return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required")
}
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
}
if c.Sora.Storage.DownloadTimeoutSeconds < 0 {
return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative")
}
if c.Sora.Storage.MaxDownloadBytes < 0 {
return fmt.Errorf("sora.storage.max_download_bytes must be non-negative")
}
if c.Sora.Storage.Cleanup.Enabled {
if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
}
if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" {
return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled")
}
} else {
if c.Sora.Storage.Cleanup.RetentionDays < 0 {
return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative")
}
}
if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" {
return fmt.Errorf("sora.storage.type must be 'local'")
}
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
switch c.Gateway.ConnectionPoolIsolation {
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
@@ -1184,7 +1722,7 @@ func (c *Config) Validate() error {
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
}
if c.Gateway.IdleConnTimeoutSeconds > 180 {
log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds)
slog.Warn("gateway.idle_conn_timeout_seconds is high; consider 60-120 seconds for better connection reuse", "idle_conn_timeout_seconds", c.Gateway.IdleConnTimeoutSeconds)
}
if c.Gateway.MaxUpstreamClients <= 0 {
return fmt.Errorf("gateway.max_upstream_clients must be positive")
@@ -1215,6 +1753,70 @@ func (c *Config) Validate() error {
if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 {
return fmt.Errorf("gateway.max_line_size must be at least 1MB")
}
if c.Gateway.UsageRecord.WorkerCount <= 0 {
return fmt.Errorf("gateway.usage_record.worker_count must be positive")
}
if c.Gateway.UsageRecord.QueueSize <= 0 {
return fmt.Errorf("gateway.usage_record.queue_size must be positive")
}
if c.Gateway.UsageRecord.TaskTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.usage_record.task_timeout_seconds must be positive")
}
switch strings.ToLower(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy)) {
case UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync:
default:
return fmt.Errorf("gateway.usage_record.overflow_policy must be one of: %s/%s/%s",
UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync)
}
if c.Gateway.UsageRecord.OverflowSamplePercent < 0 || c.Gateway.UsageRecord.OverflowSamplePercent > 100 {
return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be between 0-100")
}
if strings.EqualFold(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy), UsageRecordOverflowPolicySample) &&
c.Gateway.UsageRecord.OverflowSamplePercent <= 0 {
return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be positive when overflow_policy=sample")
}
if c.Gateway.UsageRecord.AutoScaleEnabled {
if c.Gateway.UsageRecord.AutoScaleMinWorkers <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_min_workers must be positive")
}
if c.Gateway.UsageRecord.AutoScaleMaxWorkers <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be positive")
}
if c.Gateway.UsageRecord.AutoScaleMaxWorkers < c.Gateway.UsageRecord.AutoScaleMinWorkers {
return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be >= auto_scale_min_workers")
}
if c.Gateway.UsageRecord.WorkerCount < c.Gateway.UsageRecord.AutoScaleMinWorkers ||
c.Gateway.UsageRecord.WorkerCount > c.Gateway.UsageRecord.AutoScaleMaxWorkers {
return fmt.Errorf("gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers")
}
if c.Gateway.UsageRecord.AutoScaleUpQueuePercent <= 0 || c.Gateway.UsageRecord.AutoScaleUpQueuePercent > 100 {
return fmt.Errorf("gateway.usage_record.auto_scale_up_queue_percent must be between 1-100")
}
if c.Gateway.UsageRecord.AutoScaleDownQueuePercent < 0 || c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 100 {
return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be between 0-99")
}
if c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= c.Gateway.UsageRecord.AutoScaleUpQueuePercent {
return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be less than auto_scale_up_queue_percent")
}
if c.Gateway.UsageRecord.AutoScaleUpStep <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_up_step must be positive")
}
if c.Gateway.UsageRecord.AutoScaleDownStep <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_down_step must be positive")
}
if c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_check_interval_seconds must be positive")
}
if c.Gateway.UsageRecord.AutoScaleCooldownSeconds < 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative")
}
}
if c.Gateway.UserGroupRateCacheTTLSeconds <= 0 {
return fmt.Errorf("gateway.user_group_rate_cache_ttl_seconds must be positive")
}
if c.Gateway.ModelsListCacheTTLSeconds < 10 || c.Gateway.ModelsListCacheTTLSeconds > 30 {
return fmt.Errorf("gateway.models_list_cache_ttl_seconds must be between 10-30")
}
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
}
@@ -1421,6 +2023,6 @@ func warnIfInsecureURL(field, raw string) {
return
}
if strings.EqualFold(u.Scheme, "http") {
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
slog.Warn("url uses http scheme; use https in production to avoid token leakage", "field", field)
}
}

View File

@@ -8,6 +8,25 @@ import (
"github.com/spf13/viper"
)
func resetViperWithJWTSecret(t *testing.T) {
t.Helper()
viper.Reset()
t.Setenv("JWT_SECRET", strings.Repeat("x", 32))
}
func TestLoadForBootstrapAllowsMissingJWTSecret(t *testing.T) {
viper.Reset()
t.Setenv("JWT_SECRET", "")
cfg, err := LoadForBootstrap()
if err != nil {
t.Fatalf("LoadForBootstrap() error: %v", err)
}
if cfg.JWT.Secret != "" {
t.Fatalf("LoadForBootstrap() should keep empty jwt.secret during bootstrap")
}
}
func TestNormalizeRunMode(t *testing.T) {
tests := []struct {
input string
@@ -29,7 +48,7 @@ func TestNormalizeRunMode(t *testing.T) {
}
func TestLoadDefaultSchedulingConfig(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -56,8 +75,44 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
}
}
func TestLoadDefaultIdempotencyConfig(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.Idempotency.ObserveOnly {
t.Fatalf("Idempotency.ObserveOnly = false, want true")
}
if cfg.Idempotency.DefaultTTLSeconds != 86400 {
t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 86400", cfg.Idempotency.DefaultTTLSeconds)
}
if cfg.Idempotency.SystemOperationTTLSeconds != 3600 {
t.Fatalf("Idempotency.SystemOperationTTLSeconds = %d, want 3600", cfg.Idempotency.SystemOperationTTLSeconds)
}
}
func TestLoadIdempotencyConfigFromEnv(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("IDEMPOTENCY_OBSERVE_ONLY", "false")
t.Setenv("IDEMPOTENCY_DEFAULT_TTL_SECONDS", "600")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Idempotency.ObserveOnly {
t.Fatalf("Idempotency.ObserveOnly = true, want false")
}
if cfg.Idempotency.DefaultTTLSeconds != 600 {
t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 600", cfg.Idempotency.DefaultTTLSeconds)
}
}
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
cfg, err := Load()
@@ -71,7 +126,7 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
}
func TestLoadDefaultSecurityToggles(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -87,13 +142,69 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
if !cfg.Security.URLAllowlist.AllowPrivateHosts {
t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true")
}
if cfg.Security.ResponseHeaders.Enabled {
t.Fatalf("ResponseHeaders.Enabled = true, want false")
if !cfg.Security.ResponseHeaders.Enabled {
t.Fatalf("ResponseHeaders.Enabled = false, want true")
}
}
func TestLoadDefaultServerMode(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Server.Mode != "release" {
t.Fatalf("Server.Mode = %q, want %q", cfg.Server.Mode, "release")
}
}
func TestLoadDefaultJWTAccessTokenExpireMinutes(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.JWT.ExpireHour != 24 {
t.Fatalf("JWT.ExpireHour = %d, want 24", cfg.JWT.ExpireHour)
}
if cfg.JWT.AccessTokenExpireMinutes != 0 {
t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 0", cfg.JWT.AccessTokenExpireMinutes)
}
}
func TestLoadJWTAccessTokenExpireMinutesFromEnv(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "90")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.JWT.AccessTokenExpireMinutes != 90 {
t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 90", cfg.JWT.AccessTokenExpireMinutes)
}
}
func TestLoadDefaultDatabaseSSLMode(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Database.SSLMode != "prefer" {
t.Fatalf("Database.SSLMode = %q, want %q", cfg.Database.SSLMode, "prefer")
}
}
func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -118,7 +229,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
}
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -143,7 +254,7 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
}
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -168,7 +279,7 @@ func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
}
func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -188,7 +299,7 @@ func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
}
func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -207,7 +318,7 @@ func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
}
func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -244,7 +355,7 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
}
func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -263,7 +374,7 @@ func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
}
func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -282,7 +393,7 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
}
func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -307,7 +418,7 @@ func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
}
func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -326,7 +437,7 @@ func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
}
func TestValidateUsageCleanupConfigDisabled(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -424,6 +535,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) {
}
}
func TestValidateServerFrontendURL(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Server.FrontendURL = "https://example.com"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() frontend_url valid error: %v", err)
}
cfg.Server.FrontendURL = "https://example.com/path"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() frontend_url with path valid error: %v", err)
}
cfg.Server.FrontendURL = "https://example.com?utm=1"
if err := cfg.Validate(); err == nil {
t.Fatalf("Validate() should reject server.frontend_url with query")
}
cfg.Server.FrontendURL = "https://user:pass@example.com"
if err := cfg.Validate(); err == nil {
t.Fatalf("Validate() should reject server.frontend_url with userinfo")
}
cfg.Server.FrontendURL = "/relative"
if err := cfg.Validate(); err == nil {
t.Fatalf("Validate() should reject relative server.frontend_url")
}
}
func TestValidateFrontendRedirectURL(t *testing.T) {
if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
@@ -445,6 +590,7 @@ func TestValidateFrontendRedirectURL(t *testing.T) {
func TestWarnIfInsecureURL(t *testing.T) {
warnIfInsecureURL("test", "http://example.com")
warnIfInsecureURL("test", "bad://url")
warnIfInsecureURL("test", "://invalid")
}
func TestGenerateJWTSecretDefaultLength(t *testing.T) {
@@ -458,7 +604,7 @@ func TestGenerateJWTSecretDefaultLength(t *testing.T) {
}
func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -476,7 +622,7 @@ func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
}
func TestValidateConcurrencyPingInterval(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -493,14 +639,14 @@ func TestValidateConcurrencyPingInterval(t *testing.T) {
}
func TestProvideConfig(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
if _, err := ProvideConfig(); err != nil {
t.Fatalf("ProvideConfig() error: %v", err)
}
}
func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
@@ -544,6 +690,24 @@ func TestGenerateJWTSecretWithLength(t *testing.T) {
}
}
func TestDatabaseDSNWithTimezone_WithPassword(t *testing.T) {
d := &DatabaseConfig{
Host: "localhost",
Port: 5432,
User: "u",
Password: "p",
DBName: "db",
SSLMode: "prefer",
}
got := d.DSNWithTimezone("UTC")
if !strings.Contains(got, "password=p") {
t.Fatalf("DSNWithTimezone should include password: %q", got)
}
if !strings.Contains(got, "TimeZone=UTC") {
t.Fatalf("DSNWithTimezone should include TimeZone=UTC: %q", got)
}
}
func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) {
if err := ValidateAbsoluteHTTPURL("https://"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host")
@@ -566,10 +730,35 @@ func TestWarnIfInsecureURLHTTPS(t *testing.T) {
warnIfInsecureURL("secure", "https://example.com")
}
func TestValidateJWTSecret_UTF8Bytes(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
// 31 bytes (< 32) even though it's 31 characters.
cfg.JWT.Secret = strings.Repeat("a", 31)
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() should reject 31-byte secret")
}
if !strings.Contains(err.Error(), "at least 32 bytes") {
t.Fatalf("Validate() error = %v", err)
}
// 32 bytes OK.
cfg.JWT.Secret = strings.Repeat("a", 32)
err = cfg.Validate()
if err != nil {
t.Fatalf("Validate() should accept 32-byte secret: %v", err)
}
}
func TestValidateConfigErrors(t *testing.T) {
buildValid := func(t *testing.T) *Config {
t.Helper()
viper.Reset()
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
@@ -582,6 +771,26 @@ func TestValidateConfigErrors(t *testing.T) {
mutate func(*Config)
wantErr string
}{
{
name: "jwt secret required",
mutate: func(c *Config) { c.JWT.Secret = "" },
wantErr: "jwt.secret is required",
},
{
name: "jwt secret min bytes",
mutate: func(c *Config) { c.JWT.Secret = strings.Repeat("a", 31) },
wantErr: "jwt.secret must be at least 32 bytes",
},
{
name: "subscription maintenance worker_count non-negative",
mutate: func(c *Config) { c.SubscriptionMaintenance.WorkerCount = -1 },
wantErr: "subscription_maintenance.worker_count",
},
{
name: "subscription maintenance queue_size non-negative",
mutate: func(c *Config) { c.SubscriptionMaintenance.QueueSize = -1 },
wantErr: "subscription_maintenance.queue_size",
},
{
name: "jwt expire hour positive",
mutate: func(c *Config) { c.JWT.ExpireHour = 0 },
@@ -592,6 +801,11 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.JWT.ExpireHour = 200 },
wantErr: "jwt.expire_hour must be <= 168",
},
{
name: "jwt access token expire minutes non-negative",
mutate: func(c *Config) { c.JWT.AccessTokenExpireMinutes = -1 },
wantErr: "jwt.access_token_expire_minutes must be non-negative",
},
{
name: "csp policy required",
mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" },
@@ -799,6 +1013,84 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 },
wantErr: "gateway.max_line_size must be non-negative",
},
{
name: "gateway usage record worker count",
mutate: func(c *Config) { c.Gateway.UsageRecord.WorkerCount = 0 },
wantErr: "gateway.usage_record.worker_count",
},
{
name: "gateway usage record queue size",
mutate: func(c *Config) { c.Gateway.UsageRecord.QueueSize = 0 },
wantErr: "gateway.usage_record.queue_size",
},
{
name: "gateway usage record timeout",
mutate: func(c *Config) { c.Gateway.UsageRecord.TaskTimeoutSeconds = 0 },
wantErr: "gateway.usage_record.task_timeout_seconds",
},
{
name: "gateway usage record overflow policy",
mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowPolicy = "invalid" },
wantErr: "gateway.usage_record.overflow_policy",
},
{
name: "gateway usage record sample percent range",
mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowSamplePercent = 101 },
wantErr: "gateway.usage_record.overflow_sample_percent",
},
{
name: "gateway usage record sample percent required for sample policy",
mutate: func(c *Config) {
c.Gateway.UsageRecord.OverflowPolicy = UsageRecordOverflowPolicySample
c.Gateway.UsageRecord.OverflowSamplePercent = 0
},
wantErr: "gateway.usage_record.overflow_sample_percent must be positive",
},
{
name: "gateway usage record auto scale max gte min",
mutate: func(c *Config) {
c.Gateway.UsageRecord.AutoScaleMinWorkers = 256
c.Gateway.UsageRecord.AutoScaleMaxWorkers = 128
},
wantErr: "gateway.usage_record.auto_scale_max_workers",
},
{
name: "gateway usage record worker in auto scale range",
mutate: func(c *Config) {
c.Gateway.UsageRecord.AutoScaleMinWorkers = 200
c.Gateway.UsageRecord.AutoScaleMaxWorkers = 300
c.Gateway.UsageRecord.WorkerCount = 128
},
wantErr: "gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers",
},
{
name: "gateway usage record auto scale queue thresholds order",
mutate: func(c *Config) {
c.Gateway.UsageRecord.AutoScaleUpQueuePercent = 50
c.Gateway.UsageRecord.AutoScaleDownQueuePercent = 50
},
wantErr: "gateway.usage_record.auto_scale_down_queue_percent must be less",
},
{
name: "gateway usage record auto scale up step",
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleUpStep = 0 },
wantErr: "gateway.usage_record.auto_scale_up_step",
},
{
name: "gateway usage record auto scale interval",
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 },
wantErr: "gateway.usage_record.auto_scale_check_interval_seconds",
},
{
name: "gateway user group rate cache ttl",
mutate: func(c *Config) { c.Gateway.UserGroupRateCacheTTLSeconds = 0 },
wantErr: "gateway.user_group_rate_cache_ttl_seconds",
},
{
name: "gateway models list cache ttl range",
mutate: func(c *Config) { c.Gateway.ModelsListCacheTTLSeconds = 31 },
wantErr: "gateway.models_list_cache_ttl_seconds",
},
{
name: "gateway scheduling sticky waiting",
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
@@ -822,6 +1114,37 @@ func TestValidateConfigErrors(t *testing.T) {
},
wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds",
},
{
name: "log level invalid",
mutate: func(c *Config) { c.Log.Level = "trace" },
wantErr: "log.level",
},
{
name: "log format invalid",
mutate: func(c *Config) { c.Log.Format = "plain" },
wantErr: "log.format",
},
{
name: "log output disabled",
mutate: func(c *Config) {
c.Log.Output.ToStdout = false
c.Log.Output.ToFile = false
},
wantErr: "log.output.to_stdout and log.output.to_file cannot both be false",
},
{
name: "log rotation size",
mutate: func(c *Config) { c.Log.Rotation.MaxSizeMB = 0 },
wantErr: "log.rotation.max_size_mb",
},
{
name: "log sampling enabled invalid",
mutate: func(c *Config) {
c.Log.Sampling.Enabled = true
c.Log.Sampling.Initial = 0
},
wantErr: "log.sampling.initial",
},
{
name: "ops metrics collector ttl",
mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 },
@@ -850,3 +1173,234 @@ func TestValidateConfigErrors(t *testing.T) {
})
}
}
func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Gateway.UsageRecord.AutoScaleEnabled = false
cfg.Gateway.UsageRecord.WorkerCount = 64
// 自动扩缩容关闭时,这些字段应被忽略,不应导致校验失败。
cfg.Gateway.UsageRecord.AutoScaleMinWorkers = 0
cfg.Gateway.UsageRecord.AutoScaleMaxWorkers = 0
cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent = 0
cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent = 100
cfg.Gateway.UsageRecord.AutoScaleUpStep = 0
cfg.Gateway.UsageRecord.AutoScaleDownStep = 0
cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0
cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds = -1
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() should ignore auto scale fields when disabled: %v", err)
}
}
func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) {
resetViperWithJWTSecret(t)
cases := []struct {
name string
mutate func(*Config)
wantErr string
}{
{
name: "log level required",
mutate: func(c *Config) {
c.Log.Level = ""
},
wantErr: "log.level is required",
},
{
name: "log format required",
mutate: func(c *Config) {
c.Log.Format = ""
},
wantErr: "log.format is required",
},
{
name: "log stacktrace required",
mutate: func(c *Config) {
c.Log.StacktraceLevel = ""
},
wantErr: "log.stacktrace_level is required",
},
{
name: "log max backups non-negative",
mutate: func(c *Config) {
c.Log.Rotation.MaxBackups = -1
},
wantErr: "log.rotation.max_backups must be non-negative",
},
{
name: "log max age non-negative",
mutate: func(c *Config) {
c.Log.Rotation.MaxAgeDays = -1
},
wantErr: "log.rotation.max_age_days must be non-negative",
},
{
name: "sampling thereafter non-negative when disabled",
mutate: func(c *Config) {
c.Log.Sampling.Enabled = false
c.Log.Sampling.Thereafter = -1
},
wantErr: "log.sampling.thereafter must be non-negative",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
tt.mutate(cfg)
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr)
}
})
}
}
func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.Sora.Client.CurlCFFISidecar.Enabled {
t.Fatalf("Sora curl_cffi sidecar should be enabled by default")
}
if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 {
t.Fatalf("Sora cloudflare challenge cooldown should be positive by default")
}
if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" {
t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default")
}
if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" {
t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default")
}
if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled {
t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default")
}
if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 {
t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default")
}
}
func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Sora.Client.CurlCFFISidecar.Enabled = false
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") {
t.Fatalf("Validate() error = %v, want sidecar enabled error", err)
}
}
func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Sora.Client.CurlCFFISidecar.BaseURL = " "
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") {
t.Fatalf("Validate() error = %v, want sidecar base_url required error", err)
}
}
func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") {
t.Fatalf("Validate() error = %v, want sidecar session ttl error", err)
}
}
func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1
err = cfg.Validate()
if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") {
t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
}
}
func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.UsageRecord.WorkerCount != 128 {
t.Fatalf("worker_count = %d, want 128", cfg.Gateway.UsageRecord.WorkerCount)
}
if cfg.Gateway.UsageRecord.QueueSize != 16384 {
t.Fatalf("queue_size = %d, want 16384", cfg.Gateway.UsageRecord.QueueSize)
}
if cfg.Gateway.UsageRecord.TaskTimeoutSeconds != 5 {
t.Fatalf("task_timeout_seconds = %d, want 5", cfg.Gateway.UsageRecord.TaskTimeoutSeconds)
}
if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample {
t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySample)
}
if cfg.Gateway.UsageRecord.OverflowSamplePercent != 10 {
t.Fatalf("overflow_sample_percent = %d, want 10", cfg.Gateway.UsageRecord.OverflowSamplePercent)
}
if !cfg.Gateway.UsageRecord.AutoScaleEnabled {
t.Fatalf("auto_scale_enabled = false, want true")
}
if cfg.Gateway.UsageRecord.AutoScaleMinWorkers != 128 {
t.Fatalf("auto_scale_min_workers = %d, want 128", cfg.Gateway.UsageRecord.AutoScaleMinWorkers)
}
if cfg.Gateway.UsageRecord.AutoScaleMaxWorkers != 512 {
t.Fatalf("auto_scale_max_workers = %d, want 512", cfg.Gateway.UsageRecord.AutoScaleMaxWorkers)
}
if cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent != 70 {
t.Fatalf("auto_scale_up_queue_percent = %d, want 70", cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent)
}
if cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent != 15 {
t.Fatalf("auto_scale_down_queue_percent = %d, want 15", cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent)
}
if cfg.Gateway.UsageRecord.AutoScaleUpStep != 32 {
t.Fatalf("auto_scale_up_step = %d, want 32", cfg.Gateway.UsageRecord.AutoScaleUpStep)
}
if cfg.Gateway.UsageRecord.AutoScaleDownStep != 16 {
t.Fatalf("auto_scale_down_step = %d, want 16", cfg.Gateway.UsageRecord.AutoScaleDownStep)
}
if cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds != 3 {
t.Fatalf("auto_scale_check_interval_seconds = %d, want 3", cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds)
}
if cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds != 10 {
t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds)
}
}

View File

@@ -9,5 +9,5 @@ var ProviderSet = wire.NewSet(
// ProvideConfig 提供应用配置
func ProvideConfig() (*Config, error) {
return Load()
return LoadForBootstrap()
}

View File

@@ -22,6 +22,7 @@ const (
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
PlatformSora = "sora"
)
// Account type constants

View File

@@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
return
}
dataPayload := req.Data
if err := validateDataHeader(dataPayload); err != nil {
if err := validateDataHeader(req.Data); err != nil {
response.BadRequest(c, err.Error())
return
}
executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
return h.importData(ctx, req)
})
}
func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) {
skipDefaultGroupBind := true
if req.SkipDefaultGroupBind != nil {
skipDefaultGroupBind = *req.SkipDefaultGroupBind
}
dataPayload := req.Data
result := DataImportResult{}
existingProxies, err := h.listAllProxies(c.Request.Context())
existingProxies, err := h.listAllProxies(ctx)
if err != nil {
response.ErrorFrom(c, err)
return
return result, err
}
proxyKeyToID := make(map[string]int64, len(existingProxies))
@@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
proxyKeyToID[key] = existingID
result.ProxyReused++
if normalizedStatus != "" {
if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{
if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus {
_, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
@@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
continue
}
created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: defaultProxyName(item.Name),
Protocol: item.Protocol,
Host: item.Host,
@@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
Username: item.Username,
Password: item.Password,
})
if err != nil {
if createErr != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
Message: createErr.Error(),
})
continue
}
@@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result.ProxyCreated++
if normalizedStatus != "" && normalizedStatus != created.Status {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{
_, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
@@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
SkipDefaultGroupBind: skipDefaultGroupBind,
}
if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil {
if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil {
result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "account",
@@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result.AccountCreated++
}
response.Success(c, result)
return result, nil
}
func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) {

View File

@@ -2,7 +2,13 @@
package admin
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
@@ -149,6 +155,44 @@ type AccountWithConcurrency struct {
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
}
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
item := AccountWithConcurrency{
Account: dto.AccountFromService(account),
CurrentConcurrency: 0,
}
if account == nil {
return item
}
if h.concurrencyService != nil {
if counts, err := h.concurrencyService.GetAccountConcurrencyBatch(ctx, []int64{account.ID}); err == nil {
item.CurrentConcurrency = counts[account.ID]
}
}
if account.IsAnthropicOAuthOrSetupToken() {
if h.accountUsageService != nil && account.GetWindowCostLimit() > 0 {
startTime := account.GetCurrentWindowStartTime()
if stats, err := h.accountUsageService.GetAccountWindowStats(ctx, account.ID, startTime); err == nil && stats != nil {
cost := stats.StandardCost
item.CurrentWindowCost = &cost
}
}
if h.sessionLimitCache != nil && account.GetMaxSessions() > 0 {
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
idleTimeouts := map[int64]time.Duration{account.ID: idleTimeout}
if sessions, err := h.sessionLimitCache.GetActiveSessionCountBatch(ctx, []int64{account.ID}, idleTimeouts); err == nil {
if count, ok := sessions[account.ID]; ok {
item.ActiveSessions = &count
}
}
}
}
return item
}
// List handles listing all accounts with pagination
// GET /api/v1/admin/accounts
func (h *AccountHandler) List(c *gin.Context) {
@@ -269,9 +313,71 @@ func (h *AccountHandler) List(c *gin.Context) {
result[i] = item
}
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search)
if etag != "" {
c.Header("ETag", etag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), etag) {
c.Status(http.StatusNotModified)
return
}
}
response.Paginated(c, result, total, page, pageSize)
}
func buildAccountsListETag(
items []AccountWithConcurrency,
total int64,
page, pageSize int,
platform, accountType, status, search string,
) string {
payload := struct {
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Platform string `json:"platform"`
AccountType string `json:"type"`
Status string `json:"status"`
Search string `json:"search"`
Items []AccountWithConcurrency `json:"items"`
}{
Total: total,
Page: page,
PageSize: pageSize,
Platform: platform,
AccountType: accountType,
Status: status,
Search: search,
Items: items,
}
raw, err := json.Marshal(payload)
if err != nil {
return ""
}
sum := sha256.Sum256(raw)
return "\"" + hex.EncodeToString(sum[:]) + "\""
}
func ifNoneMatchMatched(ifNoneMatch, etag string) bool {
if etag == "" || ifNoneMatch == "" {
return false
}
for _, token := range strings.Split(ifNoneMatch, ",") {
candidate := strings.TrimSpace(token)
if candidate == "*" {
return true
}
if candidate == etag {
return true
}
if strings.HasPrefix(candidate, "W/") && strings.TrimPrefix(candidate, "W/") == etag {
return true
}
}
return false
}
// GetByID handles getting an account by ID
// GET /api/v1/admin/accounts/:id
func (h *AccountHandler) GetByID(c *gin.Context) {
@@ -287,7 +393,51 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// CheckMixedChannel handles checking mixed channel risk for account-group binding.
// POST /api/v1/admin/accounts/check-mixed-channel
func (h *AccountHandler) CheckMixedChannel(c *gin.Context) {
var req CheckMixedChannelRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.GroupIDs) == 0 {
response.Success(c, gin.H{"has_risk": false})
return
}
accountID := int64(0)
if req.AccountID != nil {
accountID = *req.AccountID
}
err := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs)
if err != nil {
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
response.Success(c, gin.H{
"has_risk": true,
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
})
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"has_risk": false})
}
// CheckMixedChannel handles checking mixed channel risk for account-group binding.
@@ -350,21 +500,27 @@ func (h *AccountHandler) Create(c *gin.Context) {
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: req.Name,
Notes: req.Notes,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: req.Name,
Notes: req.Notes,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if execErr != nil {
return nil, execErr
}
return h.buildAccountResponseWithRuntime(ctx, account), nil
})
if err != nil {
// 检查是否为混合渠道错误
@@ -378,11 +534,17 @@ func (h *AccountHandler) Create(c *gin.Context) {
return
}
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AccountFromService(account))
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
}
// Update handles updating an account
@@ -439,7 +601,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// Delete handles deleting an account
@@ -697,7 +859,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
}
response.Success(c, dto.AccountFromService(updatedAccount))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
}
// GetStats handles getting account statistics
@@ -755,7 +917,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
}
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// BatchCreate handles batch creating accounts
@@ -769,61 +931,62 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
return
}
ctx := c.Request.Context()
success := 0
failed := 0
results := make([]gin.H, 0, len(req.Accounts))
executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
success := 0
failed := 0
results := make([]gin.H, 0, len(req.Accounts))
for _, item := range req.Accounts {
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
failed++
for _, item := range req.Accounts {
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": "rate_multiplier must be >= 0",
})
continue
}
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
ProxyID: item.ProxyID,
Concurrency: item.Concurrency,
Priority: item.Priority,
RateMultiplier: item.RateMultiplier,
GroupIDs: item.GroupIDs,
ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": "rate_multiplier must be >= 0",
"id": account.ID,
"success": true,
})
continue
}
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
ProxyID: item.ProxyID,
Concurrency: item.Concurrency,
Priority: item.Priority,
RateMultiplier: item.RateMultiplier,
GroupIDs: item.GroupIDs,
ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"name": item.Name,
"id": account.ID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
return gin.H{
"success": success,
"failed": failed,
"results": results,
}, nil
})
}
@@ -861,57 +1024,58 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
}
ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
// 阶段一:预验证所有账号存在,收集 credentials
type accountUpdate struct {
ID int64
Credentials map[string]any
}
updates := make([]accountUpdate, 0, len(req.AccountIDs))
for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
return
}
// Update credentials field
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials[req.Field] = req.Value
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
}
// Update account
updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials,
}
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
if err != nil {
// 阶段二:依次更新,返回每个账号的成功/失败明细,便于调用方重试
success := 0
failed := 0
successIDs := make([]int64, 0, len(updates))
failedIDs := make([]int64, 0, len(updates))
results := make([]gin.H, 0, len(updates))
for _, u := range updates {
updateInput := &service.UpdateAccountInput{Credentials: u.Credentials}
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
failed++
failedIDs = append(failedIDs, u.ID)
results = append(results, gin.H{
"account_id": accountID,
"account_id": u.ID,
"success": false,
"error": err.Error(),
})
continue
}
success++
successIDs = append(successIDs, u.ID)
results = append(results, gin.H{
"account_id": accountID,
"account_id": u.ID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
"success": success,
"failed": failed,
"success_ids": successIDs,
"failed_ids": failedIDs,
"results": results,
})
}
@@ -1146,7 +1310,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
return
}
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// GetTempUnschedulable handles getting temporary unschedulable status
@@ -1236,7 +1406,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
return
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// GetAvailableModels handles getting available models for an account
@@ -1362,6 +1532,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return
}
// Handle Sora accounts
if account.Platform == service.PlatformSora {
response.Success(c, service.DefaultSoraModels(nil))
return
}
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {

View File

@@ -0,0 +1,66 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testing.T) {
gin.SetMode(gin.TestMode)
adminSvc := newStubAdminService()
handler := NewAccountHandler(
adminSvc,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
)
router := gin.New()
router.POST("/api/v1/admin/accounts", handler.Create)
body := map[string]any{
"name": "anthropic-key-1",
"platform": "anthropic",
"type": "apikey",
"credentials": map[string]any{
"api_key": "sk-ant-xxx",
"base_url": "https://api.anthropic.com",
},
"extra": map[string]any{
"anthropic_passthrough": true,
},
"concurrency": 1,
"priority": 1,
}
raw, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(raw))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Len(t, adminSvc.createdAccounts, 1)
created := adminSvc.createdAccounts[0]
require.Equal(t, "anthropic", created.Platform)
require.Equal(t, "apikey", created.Type)
require.NotNil(t, created.Extra)
require.Equal(t, true, created.Extra["anthropic_passthrough"])
}

View File

@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality)
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
router.ServeHTTP(rec, req)

View File

@@ -58,6 +58,96 @@ func TestParseOpsDuration(t *testing.T) {
require.False(t, ok)
}
func TestParseOpsOpenAITokenStatsDuration(t *testing.T) {
tests := []struct {
input string
want time.Duration
ok bool
}{
{input: "30m", want: 30 * time.Minute, ok: true},
{input: "1h", want: time.Hour, ok: true},
{input: "1d", want: 24 * time.Hour, ok: true},
{input: "15d", want: 15 * 24 * time.Hour, ok: true},
{input: "30d", want: 30 * 24 * time.Hour, ok: true},
{input: "7d", want: 0, ok: false},
}
for _, tt := range tests {
got, ok := parseOpsOpenAITokenStatsDuration(tt.input)
require.Equal(t, tt.ok, ok, "input=%s", tt.input)
require.Equal(t, tt.want, got, "input=%s", tt.input)
}
}
func TestParseOpsOpenAITokenStatsFilter_Defaults(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
before := time.Now().UTC()
filter, err := parseOpsOpenAITokenStatsFilter(c)
after := time.Now().UTC()
require.NoError(t, err)
require.NotNil(t, filter)
require.Equal(t, "30d", filter.TimeRange)
require.Equal(t, 1, filter.Page)
require.Equal(t, 20, filter.PageSize)
require.Equal(t, 0, filter.TopN)
require.Nil(t, filter.GroupID)
require.Equal(t, "", filter.Platform)
require.True(t, filter.StartTime.Before(filter.EndTime))
require.WithinDuration(t, before.Add(-30*24*time.Hour), filter.StartTime, 2*time.Second)
require.WithinDuration(t, after, filter.EndTime, 2*time.Second)
}
func TestParseOpsOpenAITokenStatsFilter_WithTopN(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(
http.MethodGet,
"/?time_range=1h&platform=openai&group_id=12&top_n=50",
nil,
)
filter, err := parseOpsOpenAITokenStatsFilter(c)
require.NoError(t, err)
require.Equal(t, "1h", filter.TimeRange)
require.Equal(t, "openai", filter.Platform)
require.NotNil(t, filter.GroupID)
require.Equal(t, int64(12), *filter.GroupID)
require.Equal(t, 50, filter.TopN)
require.Equal(t, 0, filter.Page)
require.Equal(t, 0, filter.PageSize)
}
func TestParseOpsOpenAITokenStatsFilter_InvalidParams(t *testing.T) {
tests := []string{
"/?time_range=7d",
"/?group_id=0",
"/?group_id=abc",
"/?top_n=0",
"/?top_n=101",
"/?top_n=10&page=1",
"/?top_n=10&page_size=20",
"/?page=0",
"/?page_size=0",
"/?page_size=101",
}
gin.SetMode(gin.TestMode)
for _, rawURL := range tests {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, rawURL, nil)
_, err := parseOpsOpenAITokenStatsFilter(c)
require.Error(t, err, "url=%s", rawURL)
}
}
func TestParseOpsTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()

View File

@@ -348,6 +348,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
}
func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) {
return &service.ProxyQualityCheckResult{
ProxyID: id,
Score: 95,
Grade: "A",
Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项",
PassedCount: 5,
WarnCount: 0,
FailedCount: 0,
ChallengeCount: 0,
CheckedAt: time.Now().Unix(),
Items: []service.ProxyQualityCheckItem{
{Target: "base_connectivity", Status: "pass", Message: "ok"},
{Target: "openai", Status: "pass", HTTPStatus: 401},
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
{Target: "gemini", Status: "pass", HTTPStatus: 200},
{Target: "sora", Status: "pass", HTTPStatus: 401},
},
}, nil
}
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
return s.redeems, int64(len(s.redeems)), nil
}

View File

@@ -0,0 +1,208 @@
//go:build unit
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// failingAdminService 嵌入 stubAdminService可配置 UpdateAccount 在指定 ID 时失败。
type failingAdminService struct {
*stubAdminService
failOnAccountID int64
updateCallCount atomic.Int64
}
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
f.updateCallCount.Add(1)
if id == f.failOnAccountID {
return nil, errors.New("database error")
}
return f.stubAdminService.UpdateAccount(ctx, id, input)
}
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test-uuid",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
}
func TestBatchUpdateCredentials_PartialFailure(t *testing.T) {
// 让第 2 个账号ID=2更新时失败
svc := &failingAdminService{
stubAdminService: newStubAdminService(),
failOnAccountID: 2,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "org_uuid",
Value: "test-org",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
// 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细
require.Equal(t, http.StatusOK, w.Code, "批量更新返回 200 + 成功/失败明细")
var resp map[string]any
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
data := resp["data"].(map[string]any)
require.Equal(t, float64(2), data["success"], "应有 2 个成功")
require.Equal(t, float64(1), data["failed"], "应有 1 个失败")
// 所有 3 个账号都会被尝试更新(非 fail-fast
require.Equal(t, int64(3), svc.updateCallCount.Load(),
"应调用 3 次 UpdateAccount逐个尝试失败后继续")
}
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
svc := &getAccountFailingService{
stubAdminService: newStubAdminService(),
failOnAccountID: 1,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
}
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
type getAccountFailingService struct {
*stubAdminService
failOnAccountID int64
}
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
if id == f.failOnAccountID {
return nil, errors.New("not found")
}
return f.stubAdminService.GetAccount(ctx, id)
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// intercept_warmup_requests 传入非 bool 类型string应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": "not-a-bool",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"intercept_warmup_requests 传入非 bool 值应返回 400")
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": true,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"intercept_warmup_requests 传入合法 bool 值应返回 200")
}
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入非 string 类型number应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": 12345,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"account_uuid 传入非 string 值应返回 400")
}
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入 null设置为空应正常通过
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": nil,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"account_uuid 传入 null 应返回 200")
}

View File

@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return
}
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return

View File

@@ -61,7 +61,11 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
if err != nil {
msg := err.Error()
// Treat missing/invalid OAuth client configuration as a user/config error.
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
if strings.Contains(msg, "OAuth client not configured") ||
strings.Contains(msg, "requires your own OAuth Client") ||
strings.Contains(msg, "requires a custom OAuth Client") ||
strings.Contains(msg, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING") ||
strings.Contains(msg, "built-in Gemini CLI OAuth client_secret is not configured") {
response.BadRequest(c, "Failed to generate auth URL: "+msg)
return
}

View File

@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
@@ -38,6 +38,10 @@ type CreateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
@@ -55,7 +59,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
@@ -67,6 +71,10 @@ type UpdateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
@@ -179,6 +187,10 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
@@ -225,6 +237,10 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,

View File

@@ -0,0 +1,115 @@
package admin
import (
"context"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type idempotencyStoreUnavailableMode int
const (
idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota
idempotencyStoreUnavailableFailOpen
)
func executeAdminIdempotent(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) (*service.IdempotencyExecuteResult, error) {
coordinator := service.DefaultIdempotencyCoordinator()
if coordinator == nil {
data, err := execute(c.Request.Context())
if err != nil {
return nil, err
}
return &service.IdempotencyExecuteResult{Data: data}, nil
}
actorScope := "admin:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
return coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
Scope: scope,
ActorScope: actorScope,
Method: c.Request.Method,
Route: c.FullPath(),
IdempotencyKey: c.GetHeader("Idempotency-Key"),
Payload: payload,
RequireKey: true,
TTL: ttl,
}, execute)
}
func executeAdminIdempotentJSON(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute)
}
func executeAdminIdempotentJSONFailOpenOnStoreUnavailable(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute)
}
func executeAdminIdempotentJSONWithMode(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
mode idempotencyStoreUnavailableMode,
execute func(context.Context) (any, error),
) {
result, err := executeAdminIdempotent(c, scope, payload, ttl, execute)
if err != nil {
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
strategy := "fail_close"
if mode == idempotencyStoreUnavailableFailOpen {
strategy = "fail_open"
}
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy)
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy)
if mode == idempotencyStoreUnavailableFailOpen {
data, fallbackErr := execute(c.Request.Context())
if fallbackErr != nil {
response.ErrorFrom(c, fallbackErr)
return
}
c.Header("X-Idempotency-Degraded", "store-unavailable")
response.Success(c, data)
return
}
}
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err)
return
}
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
}

View File

@@ -0,0 +1,285 @@
package admin
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type storeUnavailableRepoStub struct{}
func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
return nil, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
return errors.New("store unavailable")
}
func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
return errors.New("store unavailable")
}
func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
return 0, errors.New("store unavailable")
}
func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "test-key-1")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
require.Equal(t, 0, executed, "fail-close should block business execution when idempotency store is unavailable")
}
func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "test-key-2")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded"))
require.Equal(t, 1, executed, "fail-open strategy should allow semantic idempotent path to continue")
}
type memoryIdempotencyRepoStub struct {
mu sync.Mutex
nextID int64
data map[string]*service.IdempotencyRecord
}
func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub {
return &memoryIdempotencyRepoStub{
nextID: 1,
data: make(map[string]*service.IdempotencyRecord),
}
}
func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string {
return scope + "|" + keyHash
}
func (r *memoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
if in == nil {
return nil
}
out := *in
if in.LockedUntil != nil {
v := *in.LockedUntil
out.LockedUntil = &v
}
if in.ResponseBody != nil {
v := *in.ResponseBody
out.ResponseBody = &v
}
if in.ResponseStatus != nil {
v := *in.ResponseStatus
out.ResponseStatus = &v
}
if in.ErrorReason != nil {
v := *in.ErrorReason
out.ErrorReason = &v
}
return &out
}
func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
k := r.key(record.Scope, record.IdempotencyKeyHash)
if _, ok := r.data[k]; ok {
return false, nil
}
cp := r.clone(record)
cp.ID = r.nextID
r.nextID++
r.data[k] = cp
record.ID = cp.ID
return true, nil
}
func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.clone(r.data[r.key(scope, keyHash)]), nil
}
func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != fromStatus {
return false, nil
}
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
return false, nil
}
rec.Status = service.IdempotencyStatusProcessing
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
rec.ErrorReason = nil
return true, nil
}
return false, nil
}
func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
return false, nil
}
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
return true, nil
}
return false, nil
}
func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusSucceeded
rec.LockedUntil = nil
rec.ExpiresAt = expiresAt
rec.ResponseStatus = &responseStatus
rec.ResponseBody = &responseBody
rec.ErrorReason = nil
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusFailedRetryable
rec.LockedUntil = &lockedUntil
rec.ExpiresAt = expiresAt
rec.ErrorReason = &errorReason
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
return 0, nil
}
func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := newMemoryIdempotencyRepoStub()
cfg := service.DefaultIdempotencyConfig()
cfg.ProcessingTimeout = 2 * time.Second
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed atomic.Int32
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed.Add(1)
time.Sleep(120 * time.Millisecond)
return gin.H{"ok": true}, nil
})
})
call := func() (int, http.Header) {
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "same-key")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
return rec.Code, rec.Header()
}
var status1, status2 int
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
status1, _ = call()
}()
go func() {
defer wg.Done()
status2, _ = call()
}()
wg.Wait()
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
require.Equal(t, int32(1), executed.Load(), "same idempotency key should execute side-effect only once")
status3, headers3 := call()
require.Equal(t, http.StatusOK, status3)
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
require.Equal(t, int32(1), executed.Load())
}

View File

@@ -2,6 +2,7 @@ package admin
import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
adminService service.AdminService
}
func oauthPlatformFromPath(c *gin.Context) string {
if strings.Contains(c.FullPath(), "/admin/sora/") {
return service.PlatformSora
}
return service.PlatformOpenAI
}
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{
@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
}
@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
RefreshToken string `json:"refresh_token"`
RT string `json:"rt"`
ClientID string `json:"client_id"`
ProxyID *int64 `json:"proxy_id"`
}
// RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token
// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
refreshToken := strings.TrimSpace(req.RefreshToken)
if refreshToken == "" {
refreshToken = strings.TrimSpace(req.RT)
}
if refreshToken == "" {
response.BadRequest(c, "refresh_token is required")
return
}
var proxyURL string
if req.ProxyID != nil {
@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
}
}
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
if err != nil {
response.ErrorFrom(c, err)
return
@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI account
// ExchangeSoraSessionToken exchanges Sora session token to access token
// POST /api/v1/admin/sora/st2at
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
var req struct {
SessionToken string `json:"session_token"`
ST string `json:"st"`
ProxyID *int64 `json:"proxy_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
sessionToken := strings.TrimSpace(req.SessionToken)
if sessionToken == "" {
sessionToken = strings.TrimSpace(req.ST)
}
if sessionToken == "" {
response.BadRequest(c, "session_token is required")
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
// POST /api/v1/admin/openai/accounts/:id/refresh
// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
return
}
// Ensure account is OpenAI platform
if !account.IsOpenAI() {
response.BadRequest(c, "Account is not an OpenAI account")
platform := oauthPlatformFromPath(c)
if account.Platform != platform {
response.BadRequest(c, "Account platform does not match OAuth endpoint")
return
}
@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
response.Success(c, dto.AccountFromService(updatedAccount))
}
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth
// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"`
@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
// Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
platform := oauthPlatformFromPath(c)
// Use email as default name if not provided
name := req.Name
if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email
}
if name == "" {
name = "OpenAI OAuth Account"
if platform == service.PlatformSora {
name = "Sora OAuth Account"
} else {
name = "OpenAI OAuth Account"
}
}
// Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name,
Platform: "openai",
Platform: platform,
Type: "oauth",
Credentials: credentials,
ProxyID: req.ProxyID,

View File

@@ -1,6 +1,7 @@
package admin
import (
"fmt"
"net/http"
"strconv"
"strings"
@@ -218,6 +219,115 @@ func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) {
response.Success(c, data)
}
// GetDashboardOpenAITokenStats returns OpenAI token efficiency stats grouped by model.
// GET /api/v1/admin/ops/dashboard/openai-token-stats
func (h *OpsHandler) GetDashboardOpenAITokenStats(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
filter, err := parseOpsOpenAITokenStatsFilter(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
data, err := h.opsService.GetOpenAITokenStats(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
func parseOpsOpenAITokenStatsFilter(c *gin.Context) (*service.OpsOpenAITokenStatsFilter, error) {
if c == nil {
return nil, fmt.Errorf("invalid request")
}
timeRange := strings.TrimSpace(c.Query("time_range"))
if timeRange == "" {
timeRange = "30d"
}
dur, ok := parseOpsOpenAITokenStatsDuration(timeRange)
if !ok {
return nil, fmt.Errorf("invalid time_range")
}
end := time.Now().UTC()
start := end.Add(-dur)
filter := &service.OpsOpenAITokenStatsFilter{
TimeRange: timeRange,
StartTime: start,
EndTime: end,
Platform: strings.TrimSpace(c.Query("platform")),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
return nil, fmt.Errorf("invalid group_id")
}
filter.GroupID = &id
}
topNRaw := strings.TrimSpace(c.Query("top_n"))
pageRaw := strings.TrimSpace(c.Query("page"))
pageSizeRaw := strings.TrimSpace(c.Query("page_size"))
if topNRaw != "" && (pageRaw != "" || pageSizeRaw != "") {
return nil, fmt.Errorf("invalid query: top_n cannot be used with page/page_size")
}
if topNRaw != "" {
topN, err := strconv.Atoi(topNRaw)
if err != nil || topN < 1 || topN > 100 {
return nil, fmt.Errorf("invalid top_n")
}
filter.TopN = topN
return filter, nil
}
filter.Page = 1
filter.PageSize = 20
if pageRaw != "" {
page, err := strconv.Atoi(pageRaw)
if err != nil || page < 1 {
return nil, fmt.Errorf("invalid page")
}
filter.Page = page
}
if pageSizeRaw != "" {
pageSize, err := strconv.Atoi(pageSizeRaw)
if err != nil || pageSize < 1 || pageSize > 100 {
return nil, fmt.Errorf("invalid page_size")
}
filter.PageSize = pageSize
}
return filter, nil
}
func parseOpsOpenAITokenStatsDuration(v string) (time.Duration, bool) {
switch strings.TrimSpace(v) {
case "30m":
return 30 * time.Minute, true
case "1h":
return time.Hour, true
case "1d":
return 24 * time.Hour, true
case "15d":
return 15 * 24 * time.Hour, true
case "30d":
return 30 * 24 * time.Hour, true
default:
return 0, false
}
}
func pickThroughputBucketSeconds(window time.Duration) int {
// Keep buckets predictable and avoid huge responses.
switch {

View File

@@ -0,0 +1,173 @@
package admin
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type testSettingRepo struct {
values map[string]string
}
func newTestSettingRepo() *testSettingRepo {
return &testSettingRepo{values: map[string]string{}}
}
func (s *testSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
v, err := s.GetValue(ctx, key)
if err != nil {
return nil, err
}
return &service.Setting{Key: key, Value: v}, nil
}
func (s *testSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
v, ok := s.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return v, nil
}
func (s *testSettingRepo) Set(ctx context.Context, key, value string) error {
s.values[key] = value
return nil
}
func (s *testSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, k := range keys {
if v, ok := s.values[k]; ok {
out[k] = v
}
}
return out, nil
}
func (s *testSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
for k, v := range settings {
s.values[k] = v
}
return nil
}
func (s *testSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(s.values))
for k, v := range s.values {
out[k] = v
}
return out, nil
}
func (s *testSettingRepo) Delete(ctx context.Context, key string) error {
delete(s.values, key)
return nil
}
func newOpsRuntimeRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 7})
c.Next()
})
}
r.GET("/runtime/logging", handler.GetRuntimeLogConfig)
r.PUT("/runtime/logging", handler.UpdateRuntimeLogConfig)
r.POST("/runtime/logging/reset", handler.ResetRuntimeLogConfig)
return r
}
func newRuntimeOpsService(t *testing.T) *service.OpsService {
t.Helper()
if err := logger.Init(logger.InitOptions{
Level: "info",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: logger.OutputOptions{
ToStdout: false,
ToFile: false,
},
}); err != nil {
t.Fatalf("init logger: %v", err)
}
settingRepo := newTestSettingRepo()
cfg := &config.Config{
Ops: config.OpsConfig{Enabled: true},
Log: config.LogConfig{
Level: "info",
Caller: true,
StacktraceLevel: "error",
Sampling: config.LogSamplingConfig{
Enabled: false,
Initial: 100,
Thereafter: 100,
},
},
}
return service.NewOpsService(nil, settingRepo, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
}
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/runtime/logging", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, true)
payload := map[string]any{
"level": "debug",
"enable_sampling": false,
"sampling_initial": 100,
"sampling_thereafter": 100,
"caller": true,
"stacktrace_level": "error",
"retention_days": 30,
}
raw, _ := json.Marshal(payload)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewReader(raw))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("update status=%d, want 200, body=%s", w.Code, w.Body.String())
}
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/runtime/logging/reset", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("reset status=%d, want 200, body=%s", w.Code, w.Body.String())
}
}

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) {
response.Success(c, updated)
}
// GetRuntimeLogConfig returns runtime log config (DB-backed).
// GET /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) GetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetRuntimeLogConfig(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get runtime log config")
return
}
response.Success(c, cfg)
}
// UpdateRuntimeLogConfig updates runtime log config and applies changes immediately.
// PUT /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) UpdateRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsRuntimeLogConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.UpdateRuntimeLogConfig(c.Request.Context(), &req, subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// ResetRuntimeLogConfig removes runtime override and falls back to env/yaml baseline.
// POST /api/v1/admin/ops/runtime/logging/reset
func (h *OpsHandler) ResetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.ResetRuntimeLogConfig(c.Request.Context(), subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// GetAdvancedSettings returns Ops advanced settings (DB-backed).
// GET /api/v1/admin/ops/advanced-settings
func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) {

View File

@@ -0,0 +1,174 @@
package admin
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type opsSystemLogCleanupRequest struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Level string `json:"level"`
Component string `json:"component"`
RequestID string `json:"request_id"`
ClientRequestID string `json:"client_request_id"`
UserID *int64 `json:"user_id"`
AccountID *int64 `json:"account_id"`
Platform string `json:"platform"`
Model string `json:"model"`
Query string `json:"q"`
}
// ListSystemLogs returns indexed system logs.
// GET /api/v1/admin/ops/system-logs
func (h *OpsHandler) ListSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 200 {
pageSize = 200
}
start, end, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsSystemLogFilter{
Page: page,
PageSize: pageSize,
StartTime: &start,
EndTime: &end,
Level: strings.TrimSpace(c.Query("level")),
Component: strings.TrimSpace(c.Query("component")),
RequestID: strings.TrimSpace(c.Query("request_id")),
ClientRequestID: strings.TrimSpace(c.Query("client_request_id")),
Platform: strings.TrimSpace(c.Query("platform")),
Model: strings.TrimSpace(c.Query("model")),
Query: strings.TrimSpace(c.Query("q")),
}
if v := strings.TrimSpace(c.Query("user_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
filter.UserID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
result, err := h.opsService.ListSystemLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Logs, int64(result.Total), result.Page, result.PageSize)
}
// CleanupSystemLogs deletes indexed system logs by filter.
// POST /api/v1/admin/ops/system-logs/cleanup
func (h *OpsHandler) CleanupSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
var req opsSystemLogCleanupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
parseTS := func(raw string) (*time.Time, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil, nil
}
if t, err := time.Parse(time.RFC3339Nano, raw); err == nil {
return &t, nil
}
t, err := time.Parse(time.RFC3339, raw)
if err != nil {
return nil, err
}
return &t, nil
}
start, err := parseTS(req.StartTime)
if err != nil {
response.BadRequest(c, "Invalid start_time")
return
}
end, err := parseTS(req.EndTime)
if err != nil {
response.BadRequest(c, "Invalid end_time")
return
}
filter := &service.OpsSystemLogCleanupFilter{
StartTime: start,
EndTime: end,
Level: strings.TrimSpace(req.Level),
Component: strings.TrimSpace(req.Component),
RequestID: strings.TrimSpace(req.RequestID),
ClientRequestID: strings.TrimSpace(req.ClientRequestID),
UserID: req.UserID,
AccountID: req.AccountID,
Platform: strings.TrimSpace(req.Platform),
Model: strings.TrimSpace(req.Model),
Query: strings.TrimSpace(req.Query),
}
deleted, err := h.opsService.CleanupSystemLogs(c.Request.Context(), filter, subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": deleted})
}
// GetSystemLogIngestionHealth returns sink health metrics.
// GET /api/v1/admin/ops/system-logs/health
func (h *OpsHandler) GetSystemLogIngestionHealth(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.opsService.GetSystemLogSinkHealth())
}

View File

@@ -0,0 +1,233 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type responseEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}
func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 99})
c.Next()
})
}
r.GET("/logs", handler.ListSystemLogs)
r.POST("/logs/cleanup", handler.CleanupSystemLogs)
r.GET("/logs/health", handler.GetSystemLogIngestionHealth)
return r
}
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?user_id=abc", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?account_id=-1", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?time_range=30m&page=1&page_size=20", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
var resp responseEnvelope
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if resp.Code != 0 {
t.Fatalf("unexpected response code: %+v", resp)
}
}
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{bad-json`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"start_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"end_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_Health(t *testing.T) {
sink := service.NewOpsSystemLogSink(nil)
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h = NewOpsHandler(svc)
r = newOpsSystemLogTestRouter(h, false)
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}

View File

@@ -3,7 +3,6 @@ package admin
import (
"context"
"encoding/json"
"log"
"math"
"net"
"net/http"
@@ -16,6 +15,7 @@ import (
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
@@ -252,7 +252,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now)
if err != nil || stats == nil {
if err != nil {
log.Printf("[OpsWS] refresh: get window stats failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: get window stats failed: %v", err)
}
return
}
@@ -278,7 +278,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
msg, err := json.Marshal(payload)
if err != nil {
log.Printf("[OpsWS] refresh: marshal payload failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: marshal payload failed: %v", err)
return
}
@@ -338,7 +338,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
// Reserve a global slot before upgrading the connection to keep the limit strict.
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
@@ -350,7 +350,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
@@ -359,7 +359,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("[OpsWS] upgrade failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] upgrade failed: %v", err)
return
}
@@ -452,7 +452,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
conn.SetReadLimit(qpsWSMaxReadBytes)
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
log.Printf("[OpsWS] set read deadline failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] set read deadline failed: %v", err)
return
}
conn.SetPongHandler(func(string) error {
@@ -471,7 +471,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
_, _, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Printf("[OpsWS] read failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] read failed: %v", err)
}
return
}
@@ -508,7 +508,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
continue
}
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
log.Printf("[OpsWS] write failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] write failed: %v", err)
cancel()
closeConn()
wg.Wait()
@@ -517,7 +517,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
case <-pingTicker.C:
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
log.Printf("[OpsWS] ping failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] ping failed: %v", err)
cancel()
closeConn()
wg.Wait()
@@ -666,14 +666,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
if parsed, err := strconv.ParseBool(v); err == nil {
cfg.TrustProxy = parsed
} else {
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
}
}
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
prefixes, invalid := parseTrustedProxyList(raw)
if len(invalid) > 0 {
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
}
cfg.TrustedProxies = prefixes
}
@@ -684,7 +684,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
case OriginPolicyStrict, OriginPolicyPermissive:
cfg.OriginPolicy = normalized
default:
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
}
}
@@ -701,14 +701,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits {
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
cfg.MaxConns = int32(parsed)
} else {
log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
}
}
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" {
if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 {
cfg.MaxConnsPerIP = int32(parsed)
} else {
log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
}
}
return cfg

View File

@@ -1,6 +1,7 @@
package admin
import (
"context"
"strconv"
"strings"
@@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) {
return
}
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
})
if err != nil {
return nil, err
}
return dto.ProxyFromService(proxy), nil
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.ProxyFromService(proxy))
}
// Update handles updating a proxy
@@ -236,6 +237,24 @@ func (h *ProxyHandler) Test(c *gin.Context) {
response.Success(c, result)
}
// CheckQuality handles checking proxy quality across common AI targets.
// POST /api/v1/admin/proxies/:id/quality-check
func (h *ProxyHandler) CheckQuality(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// GetStats handles getting proxy statistics
// GET /api/v1/admin/proxies/:id/stats
func (h *ProxyHandler) GetStats(c *gin.Context) {

View File

@@ -2,6 +2,7 @@ package admin
import (
"bytes"
"context"
"encoding/csv"
"fmt"
"strconv"
@@ -88,23 +89,24 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return
}
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{
Count: req.Count,
Type: req.Type,
Value: req.Value,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
Count: req.Count,
Type: req.Type,
Value: req.Value,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
})
if execErr != nil {
return nil, execErr
}
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
response.Success(c, out)
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
return out, nil
})
}
// Delete handles deleting a redeem code

View File

@@ -0,0 +1,97 @@
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
func truncateSearchByRune(search string, maxRunes int) string {
if runes := []rune(search); len(runes) > maxRunes {
return string(runes[:maxRunes])
}
return search
}
func TestTruncateSearchByRune(t *testing.T) {
tests := []struct {
name string
input string
maxRunes int
wantLen int // 期望的 rune 长度
}{
{
name: "纯中文超长",
input: string(make([]rune, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "纯 ASCII 超长",
input: string(make([]byte, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "空字符串",
input: "",
maxRunes: 100,
wantLen: 0,
},
{
name: "恰好 100 个字符",
input: string(make([]rune, 100)),
maxRunes: 100,
wantLen: 100,
},
{
name: "不足 100 字符不截断",
input: "hello世界",
maxRunes: 100,
wantLen: 7,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := truncateSearchByRune(tc.input, tc.maxRunes)
require.Equal(t, tc.wantLen, len([]rune(result)))
})
}
}
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
input := ""
for i := 0; i < 101; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
require.Equal(t, 100, len([]rune(result)))
// 验证截断结果是有效的 UTF-8每个中文字符 3 字节)
require.Equal(t, 300, len(result))
}
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
// 50 个 ASCII + 51 个中文 = 101 个 rune
input := ""
for i := 0; i < 50; i++ {
input += "a"
}
for i := 0; i < 51; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
runes := []rune(result)
require.Equal(t, 100, len(runes))
// 前 50 个应该是 'a',后 50 个应该是 '中'
require.Equal(t, 'a', runes[0])
require.Equal(t, 'a', runes[49])
require.Equal(t, '中', runes[50])
require.Equal(t, '中', runes[99])
}

View File

@@ -1,6 +1,7 @@
package admin
import (
"context"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return
}
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
if err != nil {
response.ErrorFrom(c, err)
return
idempotencyPayload := struct {
SubscriptionID int64 `json:"subscription_id"`
Body AdjustSubscriptionRequest `json:"body"`
}{
SubscriptionID: subscriptionID,
Body: req,
}
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days)
if execErr != nil {
return nil, execErr
}
return dto.UserSubscriptionFromServiceAdmin(subscription), nil
})
}
// Revoke handles revoking a subscription

View File

@@ -1,11 +1,15 @@
package admin
import (
"context"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -14,12 +18,14 @@ import (
// SystemHandler handles system-related operations
type SystemHandler struct {
updateSvc *service.UpdateService
lockSvc *service.SystemOperationLockService
}
// NewSystemHandler creates a new SystemHandler
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
return &SystemHandler{
updateSvc: updateSvc,
lockSvc: lockSvc,
}
}
@@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) {
// PerformUpdate downloads and applies the update
// POST /api/v1/admin/system/update
func (h *SystemHandler) PerformUpdate(c *gin.Context) {
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Update completed. Please restart the service.",
"need_restart": true,
operationID := buildSystemOperationID(c, "update")
payload := gin.H{"operation_id": operationID}
executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
lock, release, err := h.acquireSystemLock(ctx, operationID)
if err != nil {
return nil, err
}
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.PerformUpdate(ctx); err != nil {
releaseReason = "SYSTEM_UPDATE_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Update completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
})
}
// Rollback restores the previous version
// POST /api/v1/admin/system/rollback
func (h *SystemHandler) Rollback(c *gin.Context) {
if err := h.updateSvc.Rollback(); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Rollback completed. Please restart the service.",
"need_restart": true,
operationID := buildSystemOperationID(c, "rollback")
payload := gin.H{"operation_id": operationID}
executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
lock, release, err := h.acquireSystemLock(ctx, operationID)
if err != nil {
return nil, err
}
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.Rollback(); err != nil {
releaseReason = "SYSTEM_ROLLBACK_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Rollback completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
})
}
// RestartService restarts the systemd service
// POST /api/v1/admin/system/restart
func (h *SystemHandler) RestartService(c *gin.Context) {
// Schedule service restart in background after sending response
// This ensures the client receives the success response before the service restarts
go func() {
// Wait a moment to ensure the response is sent
time.Sleep(500 * time.Millisecond)
sysutil.RestartServiceAsync()
}()
operationID := buildSystemOperationID(c, "restart")
payload := gin.H{"operation_id": operationID}
executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
lock, release, err := h.acquireSystemLock(ctx, operationID)
if err != nil {
return nil, err
}
succeeded := false
defer func() {
release("", succeeded)
}()
response.Success(c, gin.H{
"message": "Service restart initiated",
// Schedule service restart in background after sending response
// This ensures the client receives the success response before the service restarts
go func() {
// Wait a moment to ensure the response is sent
time.Sleep(500 * time.Millisecond)
sysutil.RestartServiceAsync()
}()
succeeded = true
return gin.H{
"message": "Service restart initiated",
"operation_id": lock.OperationID(),
}, nil
})
}
func (h *SystemHandler) acquireSystemLock(
ctx context.Context,
operationID string,
) (*service.SystemOperationLock, func(string, bool), error) {
if h.lockSvc == nil {
return nil, nil, service.ErrIdempotencyStoreUnavail
}
lock, err := h.lockSvc.Acquire(ctx, operationID)
if err != nil {
return nil, nil, err
}
release := func(reason string, succeeded bool) {
releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = h.lockSvc.Release(releaseCtx, lock, succeeded, reason)
}
return lock, release, nil
}
func buildSystemOperationID(c *gin.Context, operation string) string {
key := strings.TrimSpace(c.GetHeader("Idempotency-Key"))
if key == "" {
return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36)
}
actorScope := "admin:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key
hash := service.HashIdempotencyKey(seed)
if len(hash) > 24 {
hash = hash[:24]
}
return "sysop-" + hash
}

View File

@@ -1,13 +1,14 @@
package admin
import (
"log"
"context"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
@@ -378,11 +379,11 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
operator = subject.UserID
}
page, pageSize := response.ParsePagination(c)
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
if err != nil {
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
response.ErrorFrom(c, err)
return
}
@@ -390,7 +391,7 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
for i := range tasks {
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
}
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
response.Paginated(c, out, result.Total, page, pageSize)
}
@@ -472,29 +473,36 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
billingType = *filters.BillingType
}
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
userID,
apiKeyID,
accountID,
groupID,
model,
stream,
billingType,
req.Timezone,
)
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
if err != nil {
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
response.ErrorFrom(c, err)
return
idempotencyPayload := struct {
OperatorID int64 `json:"operator_id"`
Body CreateUsageCleanupTaskRequest `json:"body"`
}{
OperatorID: subject.UserID,
Body: req,
}
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
userID,
apiKeyID,
accountID,
groupID,
model,
stream,
billingType,
req.Timezone,
)
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
response.Success(c, dto.UsageCleanupTaskFromService(task))
task, err := h.cleanupService.CreateTask(ctx, filters, subject.UserID)
if err != nil {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
return nil, err
}
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
return dto.UsageCleanupTaskFromService(task), nil
})
}
// CancelCleanupTask handles canceling a usage cleanup task
@@ -515,12 +523,12 @@ func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
response.BadRequest(c, "Invalid task id")
return
}
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
response.ErrorFrom(c, err)
return
}
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
}

View File

@@ -1,6 +1,7 @@
package admin
import (
"context"
"strconv"
"strings"
@@ -78,8 +79,8 @@ func (h *UserHandler) List(c *gin.Context) {
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
if runes := []rune(search); len(runes) > 100 {
search = string(runes[:100])
}
filters := service.UserListFilters{
@@ -257,13 +258,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return
}
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
if err != nil {
response.ErrorFrom(c, err)
return
idempotencyPayload := struct {
UserID int64 `json:"user_id"`
Body UpdateBalanceRequest `json:"body"`
}{
UserID: userID,
Body: req,
}
response.Success(c, dto.UserFromServiceAdmin(user))
executeAdminIdempotentJSON(c, "admin.users.balance.update", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
user, execErr := h.adminService.UpdateUserBalance(ctx, userID, req.Balance, req.Operation, req.Notes)
if execErr != nil {
return nil, execErr
}
return dto.UserFromServiceAdmin(user), nil
})
}
// GetUserAPIKeys handles getting user's API keys

View File

@@ -2,6 +2,7 @@
package handler
import (
"context"
"strconv"
"time"
@@ -130,13 +131,14 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
if req.Quota != nil {
svcReq.Quota = *req.Quota
}
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.APIKeyFromService(key))
executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq)
if err != nil {
return nil, err
}
return dto.APIKeyFromService(key), nil
})
}
// Update handles updating an API key

View File

@@ -2,6 +2,7 @@ package handler
import (
"log/slog"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -112,12 +113,11 @@ func (h *AuthHandler) Register(c *gin.Context) {
return
}
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if req.VerifyCode == "" {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
// Turnstile 验证 — 始终执行,防止绕过
// TODO: 确认前端在提交邮箱验证码注册时也传递了 turnstile_token
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
_, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
@@ -448,17 +448,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return
}
// Build frontend base URL from request
scheme := "https"
if c.Request.TLS == nil {
// Check X-Forwarded-Proto header (common in reverse proxy setups)
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
scheme = proto
} else {
scheme = "http"
}
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
if frontendBaseURL == "" {
slog.Error("server.frontend_url not configured; cannot build password reset link")
response.InternalError(c, "Password reset is not configured")
return
}
frontendBaseURL := scheme + "://" + c.Request.Host
// Request password reset (async)
// Note: This returns success even if email doesn't exist (to prevent enumeration)

View File

@@ -0,0 +1,40 @@
package dto
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestAPIKeyFromService_MapsLastUsedAt(t *testing.T) {
lastUsed := time.Now().UTC().Truncate(time.Second)
src := &service.APIKey{
ID: 1,
UserID: 2,
Key: "sk-map-last-used",
Name: "Mapper",
Status: service.StatusActive,
LastUsedAt: &lastUsed,
}
out := APIKeyFromService(src)
require.NotNil(t, out)
require.NotNil(t, out.LastUsedAt)
require.WithinDuration(t, lastUsed, *out.LastUsedAt, time.Second)
}
func TestAPIKeyFromService_MapsNilLastUsedAt(t *testing.T) {
src := &service.APIKey{
ID: 1,
UserID: 2,
Key: "sk-map-last-used-nil",
Name: "MapperNil",
Status: service.StatusActive,
}
out := APIKeyFromService(src)
require.NotNil(t, out)
require.Nil(t, out.LastUsedAt)
}

View File

@@ -2,6 +2,7 @@
package dto
import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -77,6 +78,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
Status: k.Status,
IPWhitelist: k.IPWhitelist,
IPBlacklist: k.IPBlacklist,
LastUsedAt: k.LastUsedAt,
Quota: k.Quota,
QuotaUsed: k.QuotaUsed,
ExpiresAt: k.ExpiresAt,
@@ -129,23 +131,26 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
func groupFromServiceBase(g *service.Group) Group {
return Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
// 无效请求兜底分组
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
SoraImagePrice360: g.SoraImagePrice360,
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
@@ -300,6 +305,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
CountryCode: p.CountryCode,
Region: p.Region,
City: p.City,
QualityStatus: p.QualityStatus,
QualityScore: p.QualityScore,
QualityGrade: p.QualityGrade,
QualitySummary: p.QualitySummary,
QualityChecked: p.QualityChecked,
}
}
@@ -404,6 +414,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount,
ImageSize: l.ImageSize,
MediaType: l.MediaType,
UserAgent: l.UserAgent,
CacheTTLOverridden: l.CacheTTLOverridden,
CreatedAt: l.CreatedAt,
@@ -532,11 +543,18 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
for i := range r.Subscriptions {
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
}
statuses := make(map[string]string, len(r.Statuses))
for userID, status := range r.Statuses {
statuses[strconv.FormatInt(userID, 10)] = status
}
return &BulkAssignResult{
SuccessCount: r.SuccessCount,
CreatedCount: r.CreatedCount,
ReusedCount: r.ReusedCount,
FailedCount: r.FailedCount,
Subscriptions: subs,
Errors: r.Errors,
Statuses: statuses,
}
}

View File

@@ -38,6 +38,7 @@ type APIKey struct {
Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"`
IPBlacklist []string `json:"ip_blacklist"`
LastUsedAt *time.Time `json:"last_used_at"`
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires)
@@ -67,6 +68,12 @@ type Group struct {
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
// Sora 按次计费配置
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
// Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
@@ -196,6 +203,11 @@ type ProxyWithAccountCount struct {
CountryCode string `json:"country_code,omitempty"`
Region string `json:"region,omitempty"`
City string `json:"city,omitempty"`
QualityStatus string `json:"quality_status,omitempty"`
QualityScore *int `json:"quality_score,omitempty"`
QualityGrade string `json:"quality_grade,omitempty"`
QualitySummary string `json:"quality_summary,omitempty"`
QualityChecked *int64 `json:"quality_checked,omitempty"`
}
type ProxyAccountSummary struct {
@@ -274,6 +286,7 @@ type UsageLog struct {
// 图片生成字段
ImageCount int `json:"image_count"`
ImageSize *string `json:"image_size"`
MediaType *string `json:"media_type"`
// User-Agent
UserAgent *string `json:"user_agent"`
@@ -382,9 +395,12 @@ type AdminUserSubscription struct {
type BulkAssignResult struct {
SuccessCount int `json:"success_count"`
CreatedCount int `json:"created_count"`
ReusedCount int `json:"reused_count"`
FailedCount int `json:"failed_count"`
Subscriptions []AdminUserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
Statuses map[string]string `json:"statuses,omitempty"`
}
// PromoCode 注册优惠码

View File

@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
@@ -19,11 +18,13 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// GatewayHandler handles API gateway requests
@@ -35,10 +36,12 @@ type GatewayHandler struct {
billingCacheService *service.BillingCacheService
usageService *service.UsageService
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
cfg *config.Config
}
// NewGatewayHandler creates a new GatewayHandler
@@ -51,6 +54,7 @@ func NewGatewayHandler(
billingCacheService *service.BillingCacheService,
usageService *service.UsageService,
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config,
) *GatewayHandler {
@@ -74,10 +78,12 @@ func NewGatewayHandler(
billingCacheService: billingCacheService,
usageService: usageService,
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
cfg: cfg,
}
}
@@ -96,6 +102,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.messages",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
@@ -122,6 +135,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
@@ -161,9 +175,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
reqLog.Warn("gateway.user_wait_counter_increment_failed", zap.Error(err))
// On error, allow request to proceed
} else if !canWait {
reqLog.Info("gateway.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
@@ -180,7 +195,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
reqLog.Warn("gateway.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
@@ -197,7 +212,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
@@ -227,6 +242,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
if sessionBoundAccountID > 0 {
prefetchedGroupID := int64(0)
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
c.Request = c.Request.WithContext(ctx)
}
}
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
@@ -266,7 +290,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
// 检查请求拦截预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
@@ -294,21 +318,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
reqLog.Info("gateway.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
// Ensure the wait counter is decremented if we exit before acquiring the slot.
defer func() {
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -319,17 +346,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// Slot acquired: no longer waiting in queue.
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
@@ -363,8 +388,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
}
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Forward request failed: %v", err)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
@@ -372,22 +401,27 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Account: account,
Subscription: subscription,
UserAgent: ua,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fcb,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.gateway.messages"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("gateway.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP, fs.ForceCacheBilling)
return
@@ -439,7 +473,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
// 检查请求拦截预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
@@ -467,20 +501,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
reqLog.Info("gateway.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -491,16 +529,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
@@ -523,18 +560,26 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil {
var promptTooLongErr *service.PromptTooLongError
if errors.As(err, &promptTooLongErr) {
log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed)
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
zap.Any("current_group_id", currentAPIKey.GroupID),
zap.Any("fallback_group_id", fallbackGroupID),
zap.Bool("fallback_used", fallbackUsed),
)
if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 {
fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID)
if err != nil {
log.Printf("Resolve fallback group failed: %v", err)
reqLog.Warn("gateway.resolve_fallback_group_failed", zap.Int64("fallback_group_id", *fallbackGroupID), zap.Error(err))
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
if fallbackGroup.Platform != service.PlatformAnthropic ||
fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription ||
fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType)
reqLog.Warn("gateway.fallback_group_invalid",
zap.Int64("fallback_group_id", fallbackGroup.ID),
zap.String("fallback_platform", fallbackGroup.Platform),
zap.String("fallback_subscription_type", fallbackGroup.SubscriptionType),
)
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
return
}
@@ -569,8 +614,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
}
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
@@ -578,22 +627,27 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: usedAccount,
Account: account,
Subscription: currentSubscription,
UserAgent: ua,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fcb,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.gateway.messages"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", currentAPIKey.ID),
zap.Any("group_id", currentAPIKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("gateway.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP, fs.ForceCacheBilling)
return
@@ -618,6 +672,17 @@ func (h *GatewayHandler) Models(c *gin.Context) {
groupID = &apiKey.Group.ID
platform = apiKey.Group.Platform
}
if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok && strings.TrimSpace(forcedPlatform) != "" {
platform = forcedPlatform
}
if platform == service.PlatformSora {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": service.DefaultSoraModels(h.cfg),
})
return
}
// Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
@@ -917,6 +982,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
h.errorResponse(c, status, errType, message)
}
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
if c == nil || c.Writer == nil || c.Writer.Written() {
return false
}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
return true
}
// errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
@@ -944,6 +1018,12 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gateway.count_tokens",
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
@@ -971,6 +1051,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream))
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
@@ -1004,14 +1085,15 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err))
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return
}
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
// 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
log.Printf("Forward count_tokens request failed: %v", err)
reqLog.Error("gateway.count_tokens_forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
// 错误响应已在 ForwardCountTokens 中处理
return
}
@@ -1275,7 +1357,25 @@ func billingErrorDetails(err error) (status int, code, message string) {
}
msg := pkgerrors.Message(err)
if msg == "" {
msg = err.Error()
logger.L().With(
zap.String("component", "handler.gateway.billing"),
zap.Error(err),
).Warn("gateway.billing_error_missing_message")
msg = "Billing error"
}
return http.StatusForbidden, "billing_error", msg
}
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
}
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return
}
// 回退路径worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
task(ctx)
}

View File

@@ -0,0 +1,49 @@
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &GatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.True(t, wrote)
require.Equal(t, http.StatusBadGateway, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
assert.Equal(t, "error", parsed["type"])
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errorObj["type"])
assert.Equal(t, "Upstream request failed", errorObj["message"])
}
func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.String(http.StatusTeapot, "already written")
h := &GatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.False(t, wrote)
require.Equal(t, http.StatusTeapot, w.Code)
assert.Equal(t, "already written", w.Body.String())
}

View File

@@ -4,8 +4,9 @@ import (
"context"
"encoding/json"
"fmt"
"math/rand"
"math/rand/v2"
"net/http"
"strings"
"sync"
"time"
@@ -20,14 +21,28 @@ var claudeCodeValidator = service.NewClaudeCodeValidator()
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
// 返回更新后的 context
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
// 解析请求体为 map
var bodyMap map[string]any
if len(body) > 0 {
_ = json.Unmarshal(body, &bodyMap)
if c == nil || c.Request == nil {
return
}
// Fast path非 Claude CLI UA 直接判定 false避免热路径二次 JSON 反序列化。
if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) {
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
c.Request = c.Request.WithContext(ctx)
return
}
// 验证是否为 Claude Code 客户端
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap)
isClaudeCode := false
if !strings.Contains(c.Request.URL.Path, "messages") {
// 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。
isClaudeCode = true
} else {
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
var bodyMap map[string]any
if len(body) > 0 {
_ = json.Unmarshal(body, &bodyMap)
}
isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
}
// 更新 request context
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
@@ -104,31 +119,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露
// 优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
if releaseFunc == nil {
return nil
}
var once sync.Once
quit := make(chan struct{})
var stop func() bool
release := func() {
once.Do(func() {
if stop != nil {
_ = stop()
}
releaseFunc()
close(quit) // 通知监听 goroutine 退出
})
}
go func() {
select {
case <-ctx.Done():
// Context 取消时释放资源
release()
case <-quit:
// 正常释放已完成goroutine 退出
return
}
}()
stop = context.AfterFunc(ctx, release)
return release
}
@@ -153,6 +161,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
}
// TryAcquireUserSlot 尝试立即获取用户并发槽位。
// 返回值: (releaseFunc, acquired, error)
func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) {
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return nil, false, err
}
if !result.Acquired {
return nil, false, nil
}
return result.ReleaseFunc, true, nil
}
// TryAcquireAccountSlot 尝试立即获取账号并发槽位。
// 返回值: (releaseFunc, acquired, error)
func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) {
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, false, err
}
if !result.Acquired {
return nil, false, nil
}
return result.ReleaseFunc, true, nil
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
@@ -160,13 +194,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
if acquired {
return releaseFunc, nil
}
// Need to wait - handle streaming ping if needed
@@ -180,13 +214,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
if acquired {
return releaseFunc, nil
}
// Need to wait - handle streaming ping if needed
@@ -196,27 +230,29 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted, false)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool, tryImmediate bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// Try immediate acquire first (avoid unnecessary wait)
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
acquireSlot := func() (*service.AcquireResult, error) {
if slotType == "user" {
return h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
}
return h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
if tryImmediate {
result, err := acquireSlot()
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
}
// Determine if ping is needed (streaming + ping format defined)
@@ -242,7 +278,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
backoff := initialBackoff
timer := time.NewTimer(backoff)
defer timer.Stop()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for {
select {
@@ -268,15 +303,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
case <-timer.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
result, err := acquireSlot()
if err != nil {
return nil, err
}
@@ -284,7 +311,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
if result.Acquired {
return result.ReleaseFunc, nil
}
backoff = nextBackoff(backoff, rng)
backoff = nextBackoff(backoff)
timer.Reset(backoff)
}
}
@@ -292,26 +319,22 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted, true)
}
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
// rng: 随机数生成器(可为 nil此时不添加抖动
// 返回值下一次退避时间100ms ~ 2s 之间)
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
func nextBackoff(current time.Duration) time.Duration {
// 指数退避:当前时间 * 1.5
next := time.Duration(float64(current) * backoffMultiplier)
if next > maxBackoff {
next = maxBackoff
}
if rng == nil {
return next
}
// 添加 ±20% 的随机抖动jitter 范围 0.8 ~ 1.2
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
jitter := 0.8 + rng.Float64()*0.4
jitter := 0.8 + rand.Float64()*0.4
jittered := time.Duration(float64(next) * jitter)
if jittered < initialBackoff {
return initialBackoff

View File

@@ -0,0 +1,106 @@
package handler
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Task 6.2 验证: math/rand/v2 迁移后 nextBackoff 行为正确 ---
func TestNextBackoff_ExponentialGrowth(t *testing.T) {
// 验证退避时间指数增长(乘数 1.5
// 由于有随机抖动±20%),需要验证范围
current := initialBackoff // 100ms
for i := 0; i < 10; i++ {
next := nextBackoff(current)
// 退避结果应在 [initialBackoff, maxBackoff] 范围内
assert.GreaterOrEqual(t, int64(next), int64(initialBackoff),
"第 %d 次退避不应低于初始值 %v", i, initialBackoff)
assert.LessOrEqual(t, int64(next), int64(maxBackoff),
"第 %d 次退避不应超过最大值 %v", i, maxBackoff)
// 为下一轮提供当前退避值
current = next
}
}
func TestNextBackoff_BoundedByMaxBackoff(t *testing.T) {
// 即使输入非常大,输出也不超过 maxBackoff
for i := 0; i < 100; i++ {
result := nextBackoff(10 * time.Second)
assert.LessOrEqual(t, int64(result), int64(maxBackoff),
"退避值不应超过 maxBackoff")
}
}
func TestNextBackoff_BoundedByInitialBackoff(t *testing.T) {
// 即使输入非常小,输出也不低于 initialBackoff
for i := 0; i < 100; i++ {
result := nextBackoff(1 * time.Millisecond)
assert.GreaterOrEqual(t, int64(result), int64(initialBackoff),
"退避值不应低于 initialBackoff")
}
}
func TestNextBackoff_HasJitter(t *testing.T) {
// 验证多次调用会产生不同的值(随机抖动生效)
// 使用相同的输入调用 50 次,收集结果
results := make(map[time.Duration]bool)
current := 500 * time.Millisecond
for i := 0; i < 50; i++ {
result := nextBackoff(current)
results[result] = true
}
// 50 次调用应该至少有 2 个不同的值(抖动存在)
require.Greater(t, len(results), 1,
"nextBackoff 应产生随机抖动,但所有 50 次调用结果相同")
}
func TestNextBackoff_InitialValueGrows(t *testing.T) {
// 验证从初始值开始,退避趋势是增长的
current := initialBackoff
var sum time.Duration
runs := 100
for i := 0; i < runs; i++ {
next := nextBackoff(current)
sum += next
current = next
}
avg := sum / time.Duration(runs)
// 平均退避时间应大于初始值(因为指数增长 + 上限)
assert.Greater(t, int64(avg), int64(initialBackoff),
"平均退避时间应大于初始退避值")
}
func TestNextBackoff_ConvergesToMaxBackoff(t *testing.T) {
// 从初始值开始,经过多次退避后应收敛到 maxBackoff 附近
current := initialBackoff
for i := 0; i < 20; i++ {
current = nextBackoff(current)
}
// 经过 20 次迭代后,应该已经到达 maxBackoff 区间
// 由于抖动,允许 ±20% 的范围
lowerBound := time.Duration(float64(maxBackoff) * 0.8)
assert.GreaterOrEqual(t, int64(current), int64(lowerBound),
"经过多次退避后应收敛到 maxBackoff 附近")
}
func BenchmarkNextBackoff(b *testing.B) {
current := initialBackoff
for i := 0; i < b.N; i++ {
current = nextBackoff(current)
if current > maxBackoff {
current = initialBackoff
}
}
}

View File

@@ -0,0 +1,114 @@
package handler
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
type concurrencyCacheMock struct {
acquireUserSlotFn func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
acquireAccountSlotFn func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
releaseUserCalled int32
releaseAccountCalled int32
}
func (m *concurrencyCacheMock) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
if m.acquireAccountSlotFn != nil {
return m.acquireAccountSlotFn(ctx, accountID, maxConcurrency, requestID)
}
return false, nil
}
func (m *concurrencyCacheMock) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
atomic.AddInt32(&m.releaseAccountCalled, 1)
return nil
}
func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *concurrencyCacheMock) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
return nil
}
func (m *concurrencyCacheMock) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (m *concurrencyCacheMock) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
if m.acquireUserSlotFn != nil {
return m.acquireUserSlotFn(ctx, userID, maxConcurrency, requestID)
}
return false, nil
}
func (m *concurrencyCacheMock) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
atomic.AddInt32(&m.releaseUserCalled, 1)
return nil
}
func (m *concurrencyCacheMock) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (m *concurrencyCacheMock) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *concurrencyCacheMock) DecrementWaitCount(ctx context.Context, userID int64) error {
return nil
}
func (m *concurrencyCacheMock) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
return map[int64]*service.AccountLoadInfo{}, nil
}
func (m *concurrencyCacheMock) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
return map[int64]*service.UserLoadInfo{}, nil
}
func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
return nil
}
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
}
helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
release, acquired, err := helper.TryAcquireUserSlot(context.Background(), 101, 2)
require.NoError(t, err)
require.True(t, acquired)
require.NotNil(t, release)
release()
require.Equal(t, int32(1), atomic.LoadInt32(&cache.releaseUserCalled))
}
func TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired(t *testing.T) {
cache := &concurrencyCacheMock{
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return false, nil
},
}
helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
release, acquired, err := helper.TryAcquireAccountSlot(context.Background(), 201, 1)
require.NoError(t, err)
require.False(t, acquired)
require.Nil(t, release)
require.Equal(t, int32(0), atomic.LoadInt32(&cache.releaseAccountCalled))
}

View File

@@ -0,0 +1,269 @@
package handler
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type helperConcurrencyCacheStub struct {
mu sync.Mutex
accountSeq []bool
userSeq []bool
accountAcquireCalls int
userAcquireCalls int
accountReleaseCalls int
userReleaseCalls int
}
func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.accountAcquireCalls++
if len(s.accountSeq) == 0 {
return false, nil
}
v := s.accountSeq[0]
s.accountSeq = s.accountSeq[1:]
return v, nil
}
func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
s.mu.Lock()
defer s.mu.Unlock()
s.accountReleaseCalls++
return nil
}
func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
return nil
}
func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.userAcquireCalls++
if len(s.userSeq) == 0 {
return false, nil
}
v := s.userSeq[0]
s.userSeq = s.userSeq[1:]
return v, nil
}
func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
s.mu.Lock()
defer s.mu.Unlock()
s.userReleaseCalls++
return nil
}
func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return true, nil
}
func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx context.Context, userID int64) error {
return nil
}
func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
out := make(map[int64]*service.AccountLoadInfo, len(accounts))
for _, acc := range accounts {
out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID}
}
return out, nil
}
func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
out := make(map[int64]*service.UserLoadInfo, len(users))
for _, user := range users {
out[user.ID] = &service.UserLoadInfo{UserID: user.ID}
}
return out, nil
}
func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
return nil
}
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(method, path, nil)
return c, rec
}
func validClaudeCodeBodyJSON() []byte {
return []byte(`{
"model":"claude-3-5-sonnet-20241022",
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
}`)
}
func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
t.Run("non_cli_user_agent_sets_false", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "curl/8.6.0")
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
})
t.Run("cli_non_messages_path_sets_true", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
SetClaudeCodeClientContext(c, nil)
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
c.Request.Header.Set("X-App", "claude-code")
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
c.Request.Header.Set("anthropic-version", "2023-06-01")
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
// 缺少严格校验所需 header + body 字段
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`))
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
})
}
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
cache := &helperConcurrencyCacheStub{
accountSeq: []bool{false, true},
userSeq: []bool{false, true},
}
concurrency := service.NewConcurrencyService(cache)
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
t.Run("account_slot_acquired_after_retry", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true)
require.NoError(t, err)
require.NotNil(t, release)
require.False(t, streamStarted)
release()
require.GreaterOrEqual(t, cache.accountAcquireCalls, 2)
require.GreaterOrEqual(t, cache.accountReleaseCalls, 1)
})
t.Run("user_slot_acquired_after_retry", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true)
require.NoError(t, err)
require.NotNil(t, release)
release()
require.GreaterOrEqual(t, cache.userAcquireCalls, 2)
require.GreaterOrEqual(t, cache.userReleaseCalls, 1)
})
}
func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
cache := &helperConcurrencyCacheStub{
accountSeq: []bool{false, false, false},
}
concurrency := service.NewConcurrencyService(cache)
t.Run("timeout_returns_concurrency_error", func(t *testing.T) {
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true)
require.Nil(t, release)
var cErr *ConcurrencyError
require.ErrorAs(t, err, &cErr)
require.True(t, cErr.IsTimeout)
})
t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) {
helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond)
c, rec := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true)
require.Nil(t, release)
var cErr *ConcurrencyError
require.ErrorAs(t, err, &cErr)
require.True(t, cErr.IsTimeout)
require.True(t, streamStarted)
require.Contains(t, rec.Body.String(), ":\n\n")
})
}
func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
errCache := &helperConcurrencyCacheStubWithError{
err: errors.New("redis unavailable"),
}
concurrency := service.NewConcurrencyService(errCache)
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted, true)
require.Nil(t, release)
require.Error(t, err)
require.Contains(t, err.Error(), "redis unavailable")
}
func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) {
cache := &helperConcurrencyCacheStub{
accountSeq: []bool{false},
}
concurrency := service.NewConcurrencyService(cache)
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
streamStarted := false
release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted)
require.Nil(t, release)
var cErr *ConcurrencyError
require.ErrorAs(t, err, &cErr)
require.True(t, cErr.IsTimeout)
require.GreaterOrEqual(t, cache.accountAcquireCalls, 1)
}
type helperConcurrencyCacheStubWithError struct {
helperConcurrencyCacheStub
err error
}
func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return false, s.err
}

View File

@@ -8,11 +8,9 @@ import (
"encoding/json"
"errors"
"io"
"log"
"net/http"
"regexp"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
@@ -20,11 +18,13 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/uuid"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
@@ -143,6 +143,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusInternalServerError, "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.gemini_v1beta.models",
zap.Int64("user_id", authSubject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context否则要求 gemini 分组
if !middleware.HasForcePlatform(c) {
@@ -159,6 +166,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
stream := action == "streamGenerateContent"
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
body, err := io.ReadAll(c.Request.Body)
if err != nil {
@@ -187,8 +195,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
reqLog.Warn("gemini.user_wait_counter_increment_failed", zap.Error(err))
} else if !canWait {
reqLog.Info("gemini.user_wait_queue_full", zap.Int("max_wait", maxWait))
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
@@ -208,6 +217,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
if err != nil {
reqLog.Warn("gemini.user_slot_acquire_failed", zap.Error(err))
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
@@ -223,6 +233,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
status, _, message := billingErrorDetails(err)
googleError(c, status, message)
return
@@ -252,6 +263,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
if sessionBoundAccountID > 0 {
prefetchedGroupID := int64(0)
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
c.Request = c.Request.WithContext(ctx)
}
}
// === Gemini 内容摘要会话 Fallback 逻辑 ===
@@ -296,8 +316,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
matchedDigestChain = foundMatchedChain
sessionBoundAccountID = foundAccountID
geminiSessionUUID = foundUUID
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain))
reqLog.Info("gemini.digest_fallback_matched",
zap.String("session_uuid_prefix", safeShortPrefix(foundUUID, 8)),
zap.Int64("account_id", foundAccountID),
zap.String("digest_chain", truncateDigestChain(geminiDigestChain)),
)
// 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey
// 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号
@@ -351,18 +374,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
}
account := selection.Account
setOpsSelectedAccount(c, account.ID)
setOpsSelectedAccount(c, account.ID, account.Platform)
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
// 注意Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
reqLog.Info("gemini.sticky_session_account_switched",
zap.Int64("from_account_id", sessionBoundAccountID),
zap.Int64("to_account_id", account.ID),
zap.Bool("clean_thought_signature", true),
)
body = service.CleanGeminiNativeThoughtSignatures(body)
sessionBoundAccountID = account.ID
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
// 无缓存绑定但请求里已有 thoughtSignature常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
// 为避免第一次转发就 400这里做一次确定性清理让新账号重新生成签名链路。
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
reqLog.Info("gemini.sticky_session_binding_missing",
zap.Bool("clean_thought_signature", true),
)
body = service.CleanGeminiNativeThoughtSignatures(body)
cleanedForUnknownBinding = true
sessionBoundAccountID = account.ID
@@ -381,9 +410,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
accountWaitCounted := false
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
reqLog.Warn("gemini.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
reqLog.Info("gemini.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
@@ -405,6 +437,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
&streamStarted,
)
if err != nil {
reqLog.Warn("gemini.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
@@ -413,7 +446,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
reqLog.Warn("gemini.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
@@ -436,8 +469,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch failoverAction {
case FailoverContinue:
continue
case FailoverExhausted:
@@ -448,7 +481,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
}
// ForwardNative already wrote the response
log.Printf("Gemini native forward failed: %v", err)
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
return
}
@@ -467,31 +500,39 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account.ID,
matchedDigestChain,
); err != nil {
log.Printf("[Gemini] Failed to save digest session: %v", err)
reqLog.Warn("gemini.digest_session_save_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 6) record usage async (Gemini 使用长上下文双倍计费)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Account: account,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
UserAgent: userAgent,
IPAddress: clientIP,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fcb,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.gemini_v1beta.models"),
zap.Int64("user_id", authSubject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", modelName),
zap.Int64("account_id", account.ID),
).Error("gemini.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP, fs.ForceCacheBilling)
})
reqLog.Debug("gemini.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", fs.SwitchCount),
)
return
}
}

View File

@@ -39,6 +39,7 @@ type Handlers struct {
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
SoraGateway *SoraGatewayHandler
Setting *SettingHandler
Totp *TotpHandler
}

View File

@@ -0,0 +1,65 @@
package handler
import (
"context"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
func executeUserIdempotentJSON(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
coordinator := service.DefaultIdempotencyCoordinator()
if coordinator == nil {
data, err := execute(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
return
}
actorScope := "user:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "user:" + strconv.FormatInt(subject.UserID, 10)
}
result, err := coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
Scope: scope,
ActorScope: actorScope,
Method: c.Request.Method,
Route: c.FullPath(),
IdempotencyKey: c.GetHeader("Idempotency-Key"),
Payload: payload,
RequireKey: true,
TTL: ttl,
}, execute)
if err != nil {
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_fail_close")
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=fail_close", c.Request.Method, c.FullPath(), scope)
}
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err)
return
}
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
}

View File

@@ -0,0 +1,285 @@
package handler
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type userStoreUnavailableRepoStub struct{}
func (userStoreUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
return false, errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
return nil, errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
return errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
return errors.New("store unavailable")
}
func (userStoreUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
return 0, errors.New("store unavailable")
}
type userMemoryIdempotencyRepoStub struct {
mu sync.Mutex
nextID int64
data map[string]*service.IdempotencyRecord
}
func newUserMemoryIdempotencyRepoStub() *userMemoryIdempotencyRepoStub {
return &userMemoryIdempotencyRepoStub{
nextID: 1,
data: make(map[string]*service.IdempotencyRecord),
}
}
func (r *userMemoryIdempotencyRepoStub) key(scope, keyHash string) string {
return scope + "|" + keyHash
}
func (r *userMemoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
if in == nil {
return nil
}
out := *in
if in.LockedUntil != nil {
v := *in.LockedUntil
out.LockedUntil = &v
}
if in.ResponseBody != nil {
v := *in.ResponseBody
out.ResponseBody = &v
}
if in.ResponseStatus != nil {
v := *in.ResponseStatus
out.ResponseStatus = &v
}
if in.ErrorReason != nil {
v := *in.ErrorReason
out.ErrorReason = &v
}
return &out
}
func (r *userMemoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
k := r.key(record.Scope, record.IdempotencyKeyHash)
if _, ok := r.data[k]; ok {
return false, nil
}
cp := r.clone(record)
cp.ID = r.nextID
r.nextID++
r.data[k] = cp
record.ID = cp.ID
return true, nil
}
func (r *userMemoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.clone(r.data[r.key(scope, keyHash)]), nil
}
func (r *userMemoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != fromStatus {
return false, nil
}
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
return false, nil
}
rec.Status = service.IdempotencyStatusProcessing
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
rec.ErrorReason = nil
return true, nil
}
return false, nil
}
func (r *userMemoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
return false, nil
}
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
return true, nil
}
return false, nil
}
func (r *userMemoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusSucceeded
rec.LockedUntil = nil
rec.ExpiresAt = expiresAt
rec.ResponseStatus = &responseStatus
rec.ResponseBody = &responseBody
rec.ErrorReason = nil
return nil
}
return nil
}
func (r *userMemoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusFailedRetryable
rec.LockedUntil = &lockedUntil
rec.ExpiresAt = expiresAt
rec.ErrorReason = &errorReason
return nil
}
return nil
}
func (r *userMemoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
return 0, nil
}
func withUserSubject(userID int64) gin.HandlerFunc {
return func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
c.Next()
}
}
func TestExecuteUserIdempotentJSONFallbackWithoutCoordinator(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(nil)
var executed int
router := gin.New()
router.Use(withUserSubject(1))
router.POST("/idempotent", func(c *gin.Context) {
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, executed)
}
func TestExecuteUserIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(userStoreUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.Use(withUserSubject(2))
router.POST("/idempotent", func(c *gin.Context) {
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "k1")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
require.Equal(t, 0, executed)
}
func TestExecuteUserIdempotentJSONConcurrentRetrySingleSideEffectAndReplay(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := newUserMemoryIdempotencyRepoStub()
cfg := service.DefaultIdempotencyConfig()
cfg.ProcessingTimeout = 2 * time.Second
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed atomic.Int32
router := gin.New()
router.Use(withUserSubject(3))
router.POST("/idempotent", func(c *gin.Context) {
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed.Add(1)
time.Sleep(80 * time.Millisecond)
return gin.H{"ok": true}, nil
})
})
call := func() (int, http.Header) {
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "same-user-key")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
return rec.Code, rec.Header()
}
var status1, status2 int
var wg sync.WaitGroup
wg.Add(2)
go func() { defer wg.Done(); status1, _ = call() }()
go func() { defer wg.Done(); status2, _ = call() }()
wg.Wait()
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
require.Equal(t, int32(1), executed.Load())
status3, headers3 := call()
require.Equal(t, http.StatusOK, status3)
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
require.Equal(t, int32(1), executed.Load())
}

View File

@@ -0,0 +1,19 @@
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
func requestLogger(c *gin.Context, component string, fields ...zap.Field) *zap.Logger {
base := logger.L()
if c != nil && c.Request != nil {
base = logger.FromContext(c.Request.Context())
}
if component != "" {
fields = append([]zap.Field{zap.String("component", component)}, fields...)
}
return base.With(fields...)
}

View File

@@ -6,18 +6,19 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// OpenAIGatewayHandler handles OpenAI API gateway requests
@@ -25,6 +26,7 @@ type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
@@ -36,6 +38,7 @@ func NewOpenAIGatewayHandler(
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config,
) *OpenAIGatewayHandler {
@@ -51,6 +54,7 @@ func NewOpenAIGatewayHandler(
gatewayService: gatewayService,
billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
@@ -60,6 +64,8 @@ func NewOpenAIGatewayHandler(
// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
requestStart := time.Now()
// Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
@@ -72,6 +78,13 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.responses",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
// Read request body
body, err := io.ReadAll(c.Request.Body)
@@ -91,57 +104,57 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, "", false, body)
// Parse request body to map for potential modification
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
// 校验请求体 JSON 合法性
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Extract model and stream
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
// 验证 model 必填
if reqModel == "" {
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
userAgent := c.GetHeader("User-Agent")
if !openai.IsCodexCLIRequest(userAgent) {
existingInstructions, _ := reqBody["instructions"].(string)
if strings.TrimSpace(existingInstructions) == "" {
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
reqBody["instructions"] = instructions
// Re-serialize body
body, err = json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
}
}
streamResult := gjson.GetBytes(body, "stream")
if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type")
return
}
reqStream := streamResult.Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
// 要求 previous_response_id或 input 内存在带 call_id 的 tool_call/function_call
// 或带 id 且与 call_id 匹配的 item_reference。
if service.HasFunctionCallOutput(reqBody) {
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
if service.HasFunctionCallOutputMissingCallID(reqBody) {
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return
}
callIDs := service.FunctionCallOutputCallIDs(reqBody)
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return
// 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal
if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err == nil {
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
if service.HasFunctionCallOutput(reqBody) {
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
if service.HasFunctionCallOutputMissingCallID(reqBody) {
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_call_id"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return
}
callIDs := service.FunctionCallOutputCallIDs(reqBody)
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_item_reference"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return
}
}
}
}
}
@@ -157,34 +170,48 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Get subscription info (may be nil)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
} else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
routingStart := time.Now()
// 1. First acquire user concurrency slot
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
// 0. 先尝试直接抢占用户槽位(快速路径)
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
// User slot acquired: no longer waiting.
waitCounted := false
if !userAcquired {
// 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if waitErr != nil {
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
// 按现有降级语义:等待计数异常时放行后续抢槽流程
} else if !canWait {
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if waitErr == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
}
// 用户槽位已获取:退出等待队列计数。
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
@@ -197,14 +224,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
@@ -213,12 +240,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for {
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
reqLog.Warn("openai.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
if lastFailoverErr != nil {
@@ -229,8 +259,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
setOpsSelectedAccount(c, account.ID)
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
setOpsSelectedAccount(c, account.ID, account.Platform)
// 3. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
@@ -239,53 +269,87 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
// 先快速尝试一次账号槽位,命中则跳过等待计数写入。
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
c.Request.Context(),
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
if fastAcquired {
accountReleaseFunc = fastReleaseFunc
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
} else {
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
reqLog.Info("openai.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// Forward request
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err == nil && result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -296,11 +360,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
reqLog.Warn("openai.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
// Error response already handled in Forward, just log
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("openai.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
@@ -308,27 +381,72 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// Async record usage
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Account: account,
Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("openai.record_usage_failed", zap.Error(err))
}
}(result, account, userAgent, clientIP)
})
reqLog.Debug("openai.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return
}
}
func getContextInt64(c *gin.Context, key string) (int64, bool) {
if c == nil || key == "" {
return 0, false
}
v, ok := c.Get(key)
if !ok {
return 0, false
}
switch t := v.(type) {
case int64:
return t, true
case int:
return int64(t), true
case int32:
return int64(t), true
case float64:
return int64(t), true
default:
return 0, false
}
}
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
}
if h.usageRecordWorkerPool != nil {
h.usageRecordWorkerPool.Submit(task)
return
}
// 回退路径worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
task(ctx)
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
@@ -397,8 +515,19 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Send error event in OpenAI SSE format
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
// Send error event in OpenAI SSE format with proper JSON marshaling
errorData := map[string]any{
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
@@ -411,6 +540,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
h.errorResponse(c, status, errType, message)
}
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
if c == nil || c.Writer == nil || c.Writer.Written() {
return false
}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
return true
}
// errorResponse returns OpenAI API format error response
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{

View File

@@ -0,0 +1,230 @@
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) {
tests := []struct {
name string
errType string
message string
}{
{
name: "包含双引号的消息",
errType: "server_error",
message: `upstream returned "invalid" response`,
},
{
name: "包含反斜杠的消息",
errType: "server_error",
message: `path C:\Users\test\file.txt not found`,
},
{
name: "包含双引号和反斜杠的消息",
errType: "upstream_error",
message: `error parsing "key\value": unexpected token`,
},
{
name: "包含换行符的消息",
errType: "server_error",
message: "line1\nline2\ttab",
},
{
name: "普通消息",
errType: "upstream_error",
message: "Upstream service temporarily unavailable",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &OpenAIGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
body := w.Body.String()
// 验证 SSE 格式event: error\ndata: {JSON}\n\n
assert.True(t, strings.HasPrefix(body, "event: error\n"), "应以 'event: error\\n' 开头")
assert.True(t, strings.HasSuffix(body, "\n\n"), "应以 '\\n\\n' 结尾")
// 提取 data 部分
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2, "应有 event 行和 data 行")
dataLine := lines[1]
require.True(t, strings.HasPrefix(dataLine, "data: "), "第二行应以 'data: ' 开头")
jsonStr := strings.TrimPrefix(dataLine, "data: ")
// 验证 JSON 合法性
var parsed map[string]any
err := json.Unmarshal([]byte(jsonStr), &parsed)
require.NoError(t, err, "JSON 应能被成功解析,原始 JSON: %s", jsonStr)
// 验证结构
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok, "应包含 error 对象")
assert.Equal(t, tt.errType, errorObj["type"])
assert.Equal(t, tt.message, errorObj["message"])
})
}
}
func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &OpenAIGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "test error", false)
// 非流式应返回 JSON 响应
assert.Equal(t, http.StatusBadGateway, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errorObj["type"])
assert.Equal(t, "test error", errorObj["message"])
}
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &OpenAIGatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.True(t, wrote)
require.Equal(t, http.StatusBadGateway, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errorObj["type"])
assert.Equal(t, "Upstream request failed", errorObj["message"])
}
func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.String(http.StatusTeapot, "already written")
h := &OpenAIGatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.False(t, wrote)
require.Equal(t, http.StatusTeapot, w.Code)
assert.Equal(t, "already written", w.Body.String())
}
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
tests := []struct {
name string
body string
wantModel string
wantStream bool
}{
{"正常提取", `{"model":"gpt-4","stream":true,"input":"hello"}`, "gpt-4", true},
{"stream false", `{"model":"gpt-4","stream":false}`, "gpt-4", false},
{"无 stream 字段", `{"model":"gpt-4"}`, "gpt-4", false},
{"model 缺失", `{"stream":true}`, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body := []byte(tt.body)
modelResult := gjson.GetBytes(body, "model")
model := ""
if modelResult.Type == gjson.String {
model = modelResult.String()
}
stream := gjson.GetBytes(body, "stream").Bool()
require.Equal(t, tt.wantModel, model)
require.Equal(t, tt.wantStream, stream)
})
}
}
// TestOpenAIHandler_GjsonValidation 验证修复后的 JSON 合法性和类型校验
func TestOpenAIHandler_GjsonValidation(t *testing.T) {
// 非法 JSON 被 gjson.ValidBytes 拦截
require.False(t, gjson.ValidBytes([]byte(`{invalid json`)))
// model 为数字 → 类型不是 gjson.String应被拒绝
body := []byte(`{"model":123}`)
modelResult := gjson.GetBytes(body, "model")
require.True(t, modelResult.Exists())
require.NotEqual(t, gjson.String, modelResult.Type)
// model 为 null → 类型不是 gjson.String应被拒绝
body2 := []byte(`{"model":null}`)
modelResult2 := gjson.GetBytes(body2, "model")
require.True(t, modelResult2.Exists())
require.NotEqual(t, gjson.String, modelResult2.Type)
// stream 为 string → 类型既不是 True 也不是 False应被拒绝
body3 := []byte(`{"model":"gpt-4","stream":"true"}`)
streamResult := gjson.GetBytes(body3, "stream")
require.True(t, streamResult.Exists())
require.NotEqual(t, gjson.True, streamResult.Type)
require.NotEqual(t, gjson.False, streamResult.Type)
// stream 为 int → 同上
body4 := []byte(`{"model":"gpt-4","stream":1}`)
streamResult2 := gjson.GetBytes(body4, "stream")
require.True(t, streamResult2.Exists())
require.NotEqual(t, gjson.True, streamResult2.Type)
require.NotEqual(t, gjson.False, streamResult2.Type)
}
// TestOpenAIHandler_InstructionsInjection 验证 instructions 的 gjson/sjson 注入逻辑
func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
// 测试 1无 instructions → 注入
body := []byte(`{"model":"gpt-4"}`)
existing := gjson.GetBytes(body, "instructions").String()
require.Empty(t, existing)
newBody, err := sjson.SetBytes(body, "instructions", "test instruction")
require.NoError(t, err)
require.Equal(t, "test instruction", gjson.GetBytes(newBody, "instructions").String())
// 测试 2已有 instructions → 不覆盖
body2 := []byte(`{"model":"gpt-4","instructions":"existing"}`)
existing2 := gjson.GetBytes(body2, "instructions").String()
require.Equal(t, "existing", existing2)
// 测试 3空白 instructions → 注入
body3 := []byte(`{"model":"gpt-4","instructions":" "}`)
existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String())
require.Empty(t, existing3)
// 测试 4sjson.SetBytes 返回错误时不应 panic
// 正常 JSON 不会产生 sjson 错误,验证返回值被正确处理
validBody := []byte(`{"model":"gpt-4"}`)
result, setErr := sjson.SetBytes(validBody, "instructions", "hello")
require.NoError(t, setErr)
require.True(t, gjson.ValidBytes(result))
}

View File

@@ -41,9 +41,8 @@ const (
)
type opsErrorLogJob struct {
ops *service.OpsService
entry *service.OpsInsertErrorLogInput
requestBody []byte
ops *service.OpsService
entry *service.OpsInsertErrorLogInput
}
var (
@@ -58,6 +57,7 @@ var (
opsErrorLogEnqueued atomic.Int64
opsErrorLogDropped atomic.Int64
opsErrorLogProcessed atomic.Int64
opsErrorLogSanitized atomic.Int64
opsErrorLogLastDropLogAt atomic.Int64
@@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() {
}
}()
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
_ = job.ops.RecordError(ctx, job.entry, job.requestBody)
_ = job.ops.RecordError(ctx, job.entry, nil)
cancel()
opsErrorLogProcessed.Add(1)
}()
@@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() {
}
}
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) {
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
if ops == nil || entry == nil {
return
}
@@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo
}
select {
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}:
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry}:
opsErrorLogQueueLen.Add(1)
opsErrorLogEnqueued.Add(1)
default:
@@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 {
return opsErrorLogProcessed.Load()
}
func OpsErrorLogSanitizedTotal() int64 {
return opsErrorLogSanitized.Load()
}
func maybeLogOpsErrorLogDrop() {
now := time.Now().Unix()
@@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() {
queueCap := OpsErrorLogQueueCapacity()
log.Printf(
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)",
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d sanitized_total=%d)",
queued,
queueCap,
opsErrorLogEnqueued.Load(),
opsErrorLogDropped.Load(),
opsErrorLogProcessed.Load(),
opsErrorLogSanitized.Load(),
)
}
@@ -255,18 +260,49 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
if c == nil {
return
}
model = strings.TrimSpace(model)
c.Set(opsModelKey, model)
c.Set(opsStreamKey, stream)
if len(requestBody) > 0 {
c.Set(opsRequestBodyKey, requestBody)
}
if c.Request != nil && model != "" {
ctx := context.WithValue(c.Request.Context(), ctxkey.Model, model)
c.Request = c.Request.WithContext(ctx)
}
}
func setOpsSelectedAccount(c *gin.Context, accountID int64) {
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
if c == nil || entry == nil {
return
}
v, ok := c.Get(opsRequestBodyKey)
if !ok {
return
}
raw, ok := v.([]byte)
if !ok || len(raw) == 0 {
return
}
entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw)
opsErrorLogSanitized.Add(1)
}
func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) {
if c == nil || accountID <= 0 {
return
}
c.Set(opsAccountIDKey, accountID)
if c.Request != nil {
ctx := context.WithValue(c.Request.Context(), ctxkey.AccountID, accountID)
if len(platform) > 0 {
p := strings.TrimSpace(platform[0])
if p != "" {
ctx = context.WithValue(ctx, ctxkey.Platform, p)
}
}
c.Request = c.Request.WithContext(ctx)
}
}
type opsCaptureWriter struct {
@@ -507,6 +543,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
RetryCount: 0,
CreatedAt: time.Now(),
}
applyOpsLatencyFieldsFromContext(c, entry)
if apiKey != nil {
entry.APIKeyID = &apiKey.ID
@@ -528,14 +565,9 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
entry.ClientIP = &clientIP
}
var requestBody []byte
if v, ok := c.Get(opsRequestBodyKey); ok {
if b, ok := v.([]byte); ok && len(b) > 0 {
requestBody = b
}
}
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
attachOpsRequestBodyToEntry(c, entry)
// Skip logging if a passthrough rule with skip_monitoring=true matched.
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
@@ -544,7 +576,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
}
}
enqueueOpsErrorLog(ops, entry, requestBody)
enqueueOpsErrorLog(ops, entry)
return
}
@@ -632,6 +664,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
RetryCount: 0,
CreatedAt: time.Now(),
}
applyOpsLatencyFieldsFromContext(c, entry)
// Capture upstream error context set by gateway services (if present).
// This does NOT affect the client response; it enriches Ops troubleshooting data.
@@ -707,17 +740,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
entry.ClientIP = &clientIP
}
var requestBody []byte
if v, ok := c.Get(opsRequestBodyKey); ok {
if b, ok := v.([]byte); ok && len(b) > 0 {
requestBody = b
}
}
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
// Do NOT store Authorization/Cookie/etc.
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
attachOpsRequestBodyToEntry(c, entry)
enqueueOpsErrorLog(ops, entry, requestBody)
enqueueOpsErrorLog(ops, entry)
}
}
@@ -760,6 +788,44 @@ func extractOpsRetryRequestHeaders(c *gin.Context) *string {
return &s
}
func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
if c == nil || entry == nil {
return
}
entry.AuthLatencyMs = getContextLatencyMs(c, service.OpsAuthLatencyMsKey)
entry.RoutingLatencyMs = getContextLatencyMs(c, service.OpsRoutingLatencyMsKey)
entry.UpstreamLatencyMs = getContextLatencyMs(c, service.OpsUpstreamLatencyMsKey)
entry.ResponseLatencyMs = getContextLatencyMs(c, service.OpsResponseLatencyMsKey)
entry.TimeToFirstTokenMs = getContextLatencyMs(c, service.OpsTimeToFirstTokenMsKey)
}
func getContextLatencyMs(c *gin.Context, key string) *int64 {
if c == nil || strings.TrimSpace(key) == "" {
return nil
}
v, ok := c.Get(key)
if !ok {
return nil
}
var ms int64
switch t := v.(type) {
case int:
ms = int64(t)
case int32:
ms = int64(t)
case int64:
ms = t
case float64:
ms = int64(t)
default:
return nil
}
if ms < 0 {
return nil
}
return &ms
}
type parsedOpsError struct {
ErrorType string
Message string

Some files were not shown because too many files have changed in this diff Show More