mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 15:32:13 +08:00
Compare commits
73 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bab4bb9904 | ||
|
|
33bae6f49b | ||
|
|
32d619a56b | ||
|
|
642432cf2a | ||
|
|
61e9598b08 | ||
|
|
d4e34c7514 | ||
|
|
bfe7a5e452 | ||
|
|
77d916ffec | ||
|
|
831abf7977 | ||
|
|
817a491087 | ||
|
|
9a8dacc514 | ||
|
|
8adf80d98b | ||
|
|
62686a6213 | ||
|
|
3a089242f8 | ||
|
|
9d70c38504 | ||
|
|
aeb464f3ca | ||
|
|
7076717b20 | ||
|
|
c0a4fcea0a | ||
|
|
aa2b195c86 | ||
|
|
1d0872e7ca | ||
|
|
33988637b5 | ||
|
|
d4f6ad7225 | ||
|
|
078fefed03 | ||
|
|
5b10af85b4 | ||
|
|
4caf95e5dd | ||
|
|
8e1bcf53bb | ||
|
|
064f9be7e4 | ||
|
|
adcfb44cb7 | ||
|
|
3d79773ba2 | ||
|
|
6aa8cbbf20 | ||
|
|
742e73c9c2 | ||
|
|
f8de2bdedc | ||
|
|
59879b7fa7 | ||
|
|
27abae21b8 | ||
|
|
0819c8a51a | ||
|
|
9dcd3cd491 | ||
|
|
49767cccd2 | ||
|
|
29fb447daa | ||
|
|
f6fe5b552d | ||
|
|
bd0801a887 | ||
|
|
05b1c66aa8 | ||
|
|
80ae592c23 | ||
|
|
ba6de4c4d4 | ||
|
|
46ea9170cb | ||
|
|
7d318aeefa | ||
|
|
0aa3cf677a | ||
|
|
72961c5858 | ||
|
|
a05711a37a | ||
|
|
efc9e1d673 | ||
|
|
a11ac188c2 | ||
|
|
60350d298a | ||
|
|
838dad8759 | ||
|
|
a728dfe0c6 | ||
|
|
0c7cbe3566 | ||
|
|
832b0185c7 | ||
|
|
b1719b26d1 | ||
|
|
ccf6a921c7 | ||
|
|
197c570baa | ||
|
|
0fe09f1d40 | ||
|
|
4a91954532 | ||
|
|
b8b5cec35c | ||
|
|
43c203333e | ||
|
|
1c6393b131 | ||
|
|
22f04e72e5 | ||
|
|
5f3debf65b | ||
|
|
fd8ef27535 | ||
|
|
a80ec5d8bb | ||
|
|
530a16291c | ||
|
|
7be8f4dc6e | ||
|
|
9792b17597 | ||
|
|
99f1e3ff35 | ||
|
|
5ba71cd2f1 | ||
|
|
ec6bcfeb83 |
105
AGENTS.md
105
AGENTS.md
@@ -1,105 +0,0 @@
|
||||
# Repository Guidelines
|
||||
|
||||
## Project Structure & Module Organization
|
||||
- `backend/`: Go service. `cmd/server` is the entrypoint, `internal/` contains handlers/services/repositories/server wiring, `ent/` holds Ent schemas and generated ORM code, `migrations/` stores DB migrations, and `internal/web/dist/` is the embedded frontend build output.
|
||||
- `frontend/`: Vue 3 + TypeScript app. Main folders are `src/api`, `src/components`, `src/views`, `src/stores`, `src/composables`, `src/utils`, and test files in `src/**/__tests__`.
|
||||
- `deploy/`: Docker and deployment assets (`docker-compose*.yml`, `.env.example`, `config.example.yaml`).
|
||||
- `openspec/`: Spec-driven change docs (`changes/<id>/{proposal,design,tasks}.md`).
|
||||
- `tools/`: Utility scripts (security/perf checks).
|
||||
|
||||
## Build, Test, and Development Commands
|
||||
```bash
|
||||
make build # Build backend + frontend
|
||||
make test # Backend tests + frontend lint/typecheck
|
||||
cd backend && make build # Build backend binary
|
||||
cd backend && make test-unit # Go unit tests
|
||||
cd backend && make test-integration # Go integration tests
|
||||
cd backend && make test # go test ./... + golangci-lint
|
||||
cd frontend && pnpm install --frozen-lockfile
|
||||
cd frontend && pnpm dev # Vite dev server
|
||||
cd frontend && pnpm build # Type-check + production build
|
||||
cd frontend && pnpm test:run # Vitest run
|
||||
cd frontend && pnpm test:coverage # Vitest + coverage report
|
||||
python3 tools/secret_scan.py # Secret scan
|
||||
```
|
||||
|
||||
## Coding Style & Naming Conventions
|
||||
- Go: format with `gofmt`; lint with `golangci-lint` (`backend/.golangci.yml`).
|
||||
- Respect layering: `internal/service` and `internal/handler` must not import `internal/repository`, `gorm`, or `redis` directly (enforced by depguard).
|
||||
- Frontend: Vue SFC + TypeScript, 2-space indentation, ESLint rules from `frontend/.eslintrc.cjs`.
|
||||
- Naming: components use `PascalCase.vue`, composables use `useXxx.ts`, Go tests use `*_test.go`, frontend tests use `*.spec.ts`.
|
||||
|
||||
## Go & Frontend Development Standards
|
||||
- Control branch complexity: `if` nesting must not exceed 3 levels. Refactor with guard clauses, early returns, helper functions, or strategy maps when deeper logic appears.
|
||||
- JSON hot-path rule: for read-only/partial-field extraction, prefer `gjson` over full `encoding/json` struct unmarshal to reduce allocations and improve latency.
|
||||
- Exception rule: if full schema validation or typed writes are required, `encoding/json` is allowed, but PR must explain why `gjson` is not suitable.
|
||||
|
||||
### Go Performance Rules
|
||||
- Optimization workflow rule: benchmark/profile first, then optimize. Use `go test -bench`, `go tool pprof`, and runtime diagnostics before changing hot-path code.
|
||||
- For hot functions, run escape analysis (`go build -gcflags=all='-m -m'`) and prioritize stack allocation where reasonable.
|
||||
- Every external I/O path must use `context.Context` with explicit timeout/cancel.
|
||||
- When creating derived contexts (`WithTimeout` / `WithDeadline`), always `defer cancel()` to release resources.
|
||||
- Preallocate slices/maps when size can be estimated (`make([]T, 0, n)`, `make(map[K]V, n)`).
|
||||
- Avoid unnecessary allocations in loops; reuse buffers and prefer `strings.Builder`/`bytes.Buffer`.
|
||||
- Prohibit N+1 query patterns; batch DB/Redis operations and verify indexes for new query paths.
|
||||
- For hot-path changes, include benchmark or latency comparison evidence (e.g., `go test -bench` before/after).
|
||||
- Keep goroutine growth bounded (worker pool/semaphore), and avoid unbounded fan-out.
|
||||
- Lock minimization rule: if a lock can be avoided, do not use a lock. Prefer ownership transfer (channel), sharding, immutable snapshots, copy-on-write, or atomic operations to reduce contention.
|
||||
- When locks are unavoidable, keep critical sections minimal, avoid nested locks, and document why lock-free alternatives are not feasible.
|
||||
- Follow `sync` guidance: prefer channels for higher-level synchronization; use low-level mutex primitives only where necessary.
|
||||
- Avoid reflection and `interface{}`-heavy conversions in hot paths; use typed structs/functions.
|
||||
- Use `sync.Pool` only when benchmark proves allocation reduction; remove if no measurable gain.
|
||||
- Avoid repeated `time.Now()`/`fmt.Sprintf` in tight loops; hoist or cache when possible.
|
||||
- For stable high-traffic binaries, maintain representative `default.pgo` profiles and keep `go build -pgo=auto` enabled.
|
||||
|
||||
### Data Access & Cache Rules
|
||||
- Every new/changed SQL query must be checked with `EXPLAIN` (or `EXPLAIN ANALYZE` in staging) and include index rationale in PR.
|
||||
- Default to keyset pagination for large tables; avoid deep `OFFSET` scans on hot endpoints.
|
||||
- Query only required columns; prohibit broad `SELECT *` in latency-sensitive paths.
|
||||
- Keep transactions short; never perform external RPC/network calls inside DB transactions.
|
||||
- Connection pool must be explicitly tuned and observed via `DB.Stats` (`SetMaxOpenConns`, `SetMaxIdleConns`, `SetConnMaxIdleTime`, `SetConnMaxLifetime`).
|
||||
- Avoid overly small `MaxOpenConns` that can turn DB access into lock/semaphore bottlenecks.
|
||||
- Cache keys must be versioned (e.g., `user_usage:v2:{id}`) and TTL should include jitter to avoid thundering herd.
|
||||
- Use request coalescing (`singleflight` or equivalent) for high-concurrency cache miss paths.
|
||||
|
||||
### Frontend Performance Rules
|
||||
- Route-level and heavy-module code splitting is required; lazy-load non-critical views/components.
|
||||
- API requests must support cancellation and deduplication; use debounce/throttle for search-like inputs.
|
||||
- Minimize unnecessary reactivity: avoid deep watch chains when computed/cache can solve it.
|
||||
- Prefer stable props and selective rendering controls (`v-once`, `v-memo`) for expensive subtrees when data is static or keyed.
|
||||
- Large data rendering must use pagination or virtualization (especially tables/lists >200 rows).
|
||||
- Move expensive CPU work off the main thread (Web Worker) or chunk tasks to avoid UI blocking.
|
||||
- Keep bundle growth controlled; avoid adding heavy dependencies without clear ROI and alternatives review.
|
||||
- Avoid expensive inline computations in templates; move to cached `computed` selectors.
|
||||
- Keep state normalized; avoid duplicated derived state across multiple stores/components.
|
||||
- Load charts/editors/export libraries on demand only (`dynamic import`) instead of app-entry import.
|
||||
- Core Web Vitals targets (p75): `LCP <= 2.5s`, `INP <= 200ms`, `CLS <= 0.1`.
|
||||
- Main-thread task budget: keep individual tasks below ~50ms; split long tasks and yield between chunks.
|
||||
- Enforce frontend budgets in CI (Lighthouse CI with `budget.json`) for critical routes.
|
||||
|
||||
### Performance Budget & PR Evidence
|
||||
- Performance budget is mandatory for hot-path PRs: backend p95/p99 latency and CPU/memory must not regress by more than 5% versus baseline.
|
||||
- Frontend budget: new route-level JS should not increase by more than 30KB gzip without explicit approval.
|
||||
- For any gateway/protocol hot path, attach a reproducible benchmark command and results (input size, concurrency, before/after table).
|
||||
- Profiling evidence is required for major optimizations (`pprof`, flamegraph, browser performance trace, or bundle analyzer output).
|
||||
|
||||
### Quality Gate
|
||||
- Any changed code must include new or updated unit tests.
|
||||
- Coverage must stay above 85% (global frontend threshold and no regressions for touched backend modules).
|
||||
- If any rule is intentionally violated, document reason, risk, and mitigation in the PR description.
|
||||
|
||||
## Testing Guidelines
|
||||
- Backend suites: `go test -tags=unit ./...`, `go test -tags=integration ./...`, and e2e where relevant.
|
||||
- Frontend uses Vitest (`jsdom`); keep tests near modules (`__tests__`) or as `*.spec.ts`.
|
||||
- Enforce unit-test and coverage rules defined in `Quality Gate`.
|
||||
- Before opening a PR, run `make test` plus targeted tests for touched areas.
|
||||
|
||||
## Commit & Pull Request Guidelines
|
||||
- Follow Conventional Commits: `feat(scope): ...`, `fix(scope): ...`, `chore(scope): ...`, `docs(scope): ...`.
|
||||
- PRs should include a clear summary, linked issue/spec, commands run for verification, and screenshots/GIFs for UI changes.
|
||||
- For behavior/API changes, add or update `openspec/changes/...` artifacts.
|
||||
- If dependencies change, commit `frontend/pnpm-lock.yaml` in the same PR.
|
||||
|
||||
## Security & Configuration Tips
|
||||
- Use `deploy/.env.example` and `deploy/config.example.yaml` as templates; do not commit real credentials.
|
||||
- Set stable `JWT_SECRET`, `TOTP_ENCRYPTION_KEY`, and strong database passwords outside local dev.
|
||||
@@ -137,8 +137,6 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install
|
||||
|
||||
使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。
|
||||
|
||||
如果你的服务器是 **Ubuntu 24.04**,建议直接参考:`deploy/ubuntu24-docker-compose-aicodex.md`,其中包含「安装最新版 Docker + docker-compose-aicodex.yml 部署」的完整步骤。
|
||||
|
||||
#### 前置条件
|
||||
|
||||
- Docker 20.10+
|
||||
|
||||
@@ -86,6 +86,7 @@ func provideCleanup(
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -216,6 +217,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"ScheduledTestRunnerService", func() error {
|
||||
if scheduledTestRunner != nil {
|
||||
scheduledTestRunner.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -58,11 +58,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
promoCodeRepository := repository.NewPromoCodeRepository(client)
|
||||
billingCache := repository.NewBillingCache(redisClient)
|
||||
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
|
||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
|
||||
apiKeyRepository := repository.NewAPIKeyRepository(client)
|
||||
apiKeyRepository := repository.NewAPIKeyRepository(client, db)
|
||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig)
|
||||
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
|
||||
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
|
||||
apiKeyService.SetRateLimitCacheInvalidator(billingCache)
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||
@@ -194,7 +195,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
|
||||
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
|
||||
adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
|
||||
scheduledTestPlanRepository := repository.NewScheduledTestPlanRepository(db)
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -221,10 +226,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -272,6 +278,7 @@ func provideCleanup(
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -401,6 +408,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"ScheduledTestRunnerService", func() error {
|
||||
if scheduledTestRunner != nil {
|
||||
scheduledTestRunner.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -37,12 +37,13 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
)
|
||||
accountExpirySvc := service.NewAccountExpiryService(nil, time.Second)
|
||||
subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second)
|
||||
pricingSvc := service.NewPricingService(cfg, nil)
|
||||
emailQueueSvc := service.NewEmailQueueService(nil, 1)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
|
||||
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
|
||||
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
|
||||
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)
|
||||
@@ -73,6 +74,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
geminiOAuthSvc,
|
||||
antigravityOAuthSvc,
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
@@ -48,6 +48,24 @@ type APIKey struct {
|
||||
QuotaUsed float64 `json:"quota_used,omitempty"`
|
||||
// Expiration time for this API key (null = never expires)
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
// Rate limit in USD per 5 hours (0 = unlimited)
|
||||
RateLimit5h float64 `json:"rate_limit_5h,omitempty"`
|
||||
// Rate limit in USD per day (0 = unlimited)
|
||||
RateLimit1d float64 `json:"rate_limit_1d,omitempty"`
|
||||
// Rate limit in USD per 7 days (0 = unlimited)
|
||||
RateLimit7d float64 `json:"rate_limit_7d,omitempty"`
|
||||
// Used amount in USD for the current 5h window
|
||||
Usage5h float64 `json:"usage_5h,omitempty"`
|
||||
// Used amount in USD for the current 1d window
|
||||
Usage1d float64 `json:"usage_1d,omitempty"`
|
||||
// Used amount in USD for the current 7d window
|
||||
Usage7d float64 `json:"usage_7d,omitempty"`
|
||||
// Start time of the current 5h rate limit window
|
||||
Window5hStart *time.Time `json:"window_5h_start,omitempty"`
|
||||
// Start time of the current 1d rate limit window
|
||||
Window1dStart *time.Time `json:"window_1d_start,omitempty"`
|
||||
// Start time of the current 7d rate limit window
|
||||
Window7dStart *time.Time `json:"window_7d_start,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the APIKeyQuery when eager-loading is set.
|
||||
Edges APIKeyEdges `json:"edges"`
|
||||
@@ -105,13 +123,13 @@ func (*APIKey) scanValues(columns []string) ([]any, error) {
|
||||
switch columns[i] {
|
||||
case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist:
|
||||
values[i] = new([]byte)
|
||||
case apikey.FieldQuota, apikey.FieldQuotaUsed:
|
||||
case apikey.FieldQuota, apikey.FieldQuotaUsed, apikey.FieldRateLimit5h, apikey.FieldRateLimit1d, apikey.FieldRateLimit7d, apikey.FieldUsage5h, apikey.FieldUsage1d, apikey.FieldUsage7d:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
|
||||
values[i] = new(sql.NullString)
|
||||
case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldLastUsedAt, apikey.FieldExpiresAt:
|
||||
case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldLastUsedAt, apikey.FieldExpiresAt, apikey.FieldWindow5hStart, apikey.FieldWindow1dStart, apikey.FieldWindow7dStart:
|
||||
values[i] = new(sql.NullTime)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
@@ -226,6 +244,63 @@ func (_m *APIKey) assignValues(columns []string, values []any) error {
|
||||
_m.ExpiresAt = new(time.Time)
|
||||
*_m.ExpiresAt = value.Time
|
||||
}
|
||||
case apikey.FieldRateLimit5h:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field rate_limit_5h", values[i])
|
||||
} else if value.Valid {
|
||||
_m.RateLimit5h = value.Float64
|
||||
}
|
||||
case apikey.FieldRateLimit1d:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field rate_limit_1d", values[i])
|
||||
} else if value.Valid {
|
||||
_m.RateLimit1d = value.Float64
|
||||
}
|
||||
case apikey.FieldRateLimit7d:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field rate_limit_7d", values[i])
|
||||
} else if value.Valid {
|
||||
_m.RateLimit7d = value.Float64
|
||||
}
|
||||
case apikey.FieldUsage5h:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field usage_5h", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Usage5h = value.Float64
|
||||
}
|
||||
case apikey.FieldUsage1d:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field usage_1d", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Usage1d = value.Float64
|
||||
}
|
||||
case apikey.FieldUsage7d:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field usage_7d", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Usage7d = value.Float64
|
||||
}
|
||||
case apikey.FieldWindow5hStart:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field window_5h_start", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Window5hStart = new(time.Time)
|
||||
*_m.Window5hStart = value.Time
|
||||
}
|
||||
case apikey.FieldWindow1dStart:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field window_1d_start", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Window1dStart = new(time.Time)
|
||||
*_m.Window1dStart = value.Time
|
||||
}
|
||||
case apikey.FieldWindow7dStart:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field window_7d_start", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Window7dStart = new(time.Time)
|
||||
*_m.Window7dStart = value.Time
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -326,6 +401,39 @@ func (_m *APIKey) String() string {
|
||||
builder.WriteString("expires_at=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("rate_limit_5h=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.RateLimit5h))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("rate_limit_1d=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.RateLimit1d))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("rate_limit_7d=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.RateLimit7d))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("usage_5h=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Usage5h))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("usage_1d=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Usage1d))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("usage_7d=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Usage7d))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.Window5hStart; v != nil {
|
||||
builder.WriteString("window_5h_start=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.Window1dStart; v != nil {
|
||||
builder.WriteString("window_1d_start=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.Window7dStart; v != nil {
|
||||
builder.WriteString("window_7d_start=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -43,6 +43,24 @@ const (
|
||||
FieldQuotaUsed = "quota_used"
|
||||
// FieldExpiresAt holds the string denoting the expires_at field in the database.
|
||||
FieldExpiresAt = "expires_at"
|
||||
// FieldRateLimit5h holds the string denoting the rate_limit_5h field in the database.
|
||||
FieldRateLimit5h = "rate_limit_5h"
|
||||
// FieldRateLimit1d holds the string denoting the rate_limit_1d field in the database.
|
||||
FieldRateLimit1d = "rate_limit_1d"
|
||||
// FieldRateLimit7d holds the string denoting the rate_limit_7d field in the database.
|
||||
FieldRateLimit7d = "rate_limit_7d"
|
||||
// FieldUsage5h holds the string denoting the usage_5h field in the database.
|
||||
FieldUsage5h = "usage_5h"
|
||||
// FieldUsage1d holds the string denoting the usage_1d field in the database.
|
||||
FieldUsage1d = "usage_1d"
|
||||
// FieldUsage7d holds the string denoting the usage_7d field in the database.
|
||||
FieldUsage7d = "usage_7d"
|
||||
// FieldWindow5hStart holds the string denoting the window_5h_start field in the database.
|
||||
FieldWindow5hStart = "window_5h_start"
|
||||
// FieldWindow1dStart holds the string denoting the window_1d_start field in the database.
|
||||
FieldWindow1dStart = "window_1d_start"
|
||||
// FieldWindow7dStart holds the string denoting the window_7d_start field in the database.
|
||||
FieldWindow7dStart = "window_7d_start"
|
||||
// EdgeUser holds the string denoting the user edge name in mutations.
|
||||
EdgeUser = "user"
|
||||
// EdgeGroup holds the string denoting the group edge name in mutations.
|
||||
@@ -91,6 +109,15 @@ var Columns = []string{
|
||||
FieldQuota,
|
||||
FieldQuotaUsed,
|
||||
FieldExpiresAt,
|
||||
FieldRateLimit5h,
|
||||
FieldRateLimit1d,
|
||||
FieldRateLimit7d,
|
||||
FieldUsage5h,
|
||||
FieldUsage1d,
|
||||
FieldUsage7d,
|
||||
FieldWindow5hStart,
|
||||
FieldWindow1dStart,
|
||||
FieldWindow7dStart,
|
||||
}
|
||||
|
||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||
@@ -129,6 +156,18 @@ var (
|
||||
DefaultQuota float64
|
||||
// DefaultQuotaUsed holds the default value on creation for the "quota_used" field.
|
||||
DefaultQuotaUsed float64
|
||||
// DefaultRateLimit5h holds the default value on creation for the "rate_limit_5h" field.
|
||||
DefaultRateLimit5h float64
|
||||
// DefaultRateLimit1d holds the default value on creation for the "rate_limit_1d" field.
|
||||
DefaultRateLimit1d float64
|
||||
// DefaultRateLimit7d holds the default value on creation for the "rate_limit_7d" field.
|
||||
DefaultRateLimit7d float64
|
||||
// DefaultUsage5h holds the default value on creation for the "usage_5h" field.
|
||||
DefaultUsage5h float64
|
||||
// DefaultUsage1d holds the default value on creation for the "usage_1d" field.
|
||||
DefaultUsage1d float64
|
||||
// DefaultUsage7d holds the default value on creation for the "usage_7d" field.
|
||||
DefaultUsage7d float64
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the APIKey queries.
|
||||
@@ -199,6 +238,51 @@ func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByRateLimit5h orders the results by the rate_limit_5h field.
|
||||
func ByRateLimit5h(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldRateLimit5h, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByRateLimit1d orders the results by the rate_limit_1d field.
|
||||
func ByRateLimit1d(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldRateLimit1d, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByRateLimit7d orders the results by the rate_limit_7d field.
|
||||
func ByRateLimit7d(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldRateLimit7d, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUsage5h orders the results by the usage_5h field.
|
||||
func ByUsage5h(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUsage5h, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUsage1d orders the results by the usage_1d field.
|
||||
func ByUsage1d(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUsage1d, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUsage7d orders the results by the usage_7d field.
|
||||
func ByUsage7d(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUsage7d, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByWindow5hStart orders the results by the window_5h_start field.
|
||||
func ByWindow5hStart(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldWindow5hStart, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByWindow1dStart orders the results by the window_1d_start field.
|
||||
func ByWindow1dStart(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldWindow1dStart, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByWindow7dStart orders the results by the window_7d_start field.
|
||||
func ByWindow7dStart(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldWindow7dStart, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUserField orders the results by user field.
|
||||
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -115,6 +115,51 @@ func ExpiresAt(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// RateLimit5h applies equality check predicate on the "rate_limit_5h" field. It's identical to RateLimit5hEQ.
|
||||
func RateLimit5h(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldRateLimit5h, v))
|
||||
}
|
||||
|
||||
// RateLimit1d applies equality check predicate on the "rate_limit_1d" field. It's identical to RateLimit1dEQ.
|
||||
func RateLimit1d(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldRateLimit1d, v))
|
||||
}
|
||||
|
||||
// RateLimit7d applies equality check predicate on the "rate_limit_7d" field. It's identical to RateLimit7dEQ.
|
||||
func RateLimit7d(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldRateLimit7d, v))
|
||||
}
|
||||
|
||||
// Usage5h applies equality check predicate on the "usage_5h" field. It's identical to Usage5hEQ.
|
||||
func Usage5h(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldUsage5h, v))
|
||||
}
|
||||
|
||||
// Usage1d applies equality check predicate on the "usage_1d" field. It's identical to Usage1dEQ.
|
||||
func Usage1d(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldUsage1d, v))
|
||||
}
|
||||
|
||||
// Usage7d applies equality check predicate on the "usage_7d" field. It's identical to Usage7dEQ.
|
||||
func Usage7d(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldUsage7d, v))
|
||||
}
|
||||
|
||||
// Window5hStart applies equality check predicate on the "window_5h_start" field. It's identical to Window5hStartEQ.
|
||||
func Window5hStart(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldWindow5hStart, v))
|
||||
}
|
||||
|
||||
// Window1dStart applies equality check predicate on the "window_1d_start" field. It's identical to Window1dStartEQ.
|
||||
func Window1dStart(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldWindow1dStart, v))
|
||||
}
|
||||
|
||||
// Window7dStart applies equality check predicate on the "window_7d_start" field. It's identical to Window7dStartEQ.
|
||||
func Window7dStart(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldWindow7dStart, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -690,6 +735,396 @@ func ExpiresAtNotNil() predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotNull(FieldExpiresAt))
|
||||
}
|
||||
|
||||
// RateLimit5hEQ applies the EQ predicate on the "rate_limit_5h" field.
|
||||
func RateLimit5hEQ(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldRateLimit5h, v))
|
||||
}
|
||||
|
||||
// RateLimit5hNEQ applies the NEQ predicate on the "rate_limit_5h" field.
|
||||
func RateLimit5hNEQ(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNEQ(FieldRateLimit5h, v))
|
||||
}
|
||||
|
||||
// RateLimit5hIn applies the In predicate on the "rate_limit_5h" field.
|
||||
func RateLimit5hIn(vs ...float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIn(FieldRateLimit5h, vs...))
|
||||
}
|
||||
|
||||
// RateLimit5hNotIn applies the NotIn predicate on the "rate_limit_5h" field.
|
||||
func RateLimit5hNotIn(vs ...float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotIn(FieldRateLimit5h, vs...))
|
||||
}
|
||||
|
||||
// RateLimit5hGT applies the GT predicate on the "rate_limit_5h" field.
|
||||
func RateLimit5hGT(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGT(FieldRateLimit5h, v))
|
||||
}
|
||||
|
||||
// RateLimit5hGTE applies the GTE predicate on the "rate_limit_5h" field.
|
||||
func RateLimit5hGTE(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGTE(FieldRateLimit5h, v))
|
||||
}
|
||||
|
||||
// RateLimit5hLT applies the LT predicate on the "rate_limit_5h" field.
|
||||
func RateLimit5hLT(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLT(FieldRateLimit5h, v))
|
||||
}
|
||||
|
||||
// RateLimit5hLTE applies the LTE predicate on the "rate_limit_5h" field.
|
||||
func RateLimit5hLTE(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLTE(FieldRateLimit5h, v))
|
||||
}
|
||||
|
||||
// RateLimit1dEQ applies the EQ predicate on the "rate_limit_1d" field.
|
||||
func RateLimit1dEQ(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldRateLimit1d, v))
|
||||
}
|
||||
|
||||
// RateLimit1dNEQ applies the NEQ predicate on the "rate_limit_1d" field.
|
||||
func RateLimit1dNEQ(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNEQ(FieldRateLimit1d, v))
|
||||
}
|
||||
|
||||
// RateLimit1dIn applies the In predicate on the "rate_limit_1d" field.
|
||||
func RateLimit1dIn(vs ...float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIn(FieldRateLimit1d, vs...))
|
||||
}
|
||||
|
||||
// RateLimit1dNotIn applies the NotIn predicate on the "rate_limit_1d" field.
|
||||
func RateLimit1dNotIn(vs ...float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotIn(FieldRateLimit1d, vs...))
|
||||
}
|
||||
|
||||
// RateLimit1dGT applies the GT predicate on the "rate_limit_1d" field.
|
||||
func RateLimit1dGT(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGT(FieldRateLimit1d, v))
|
||||
}
|
||||
|
||||
// RateLimit1dGTE applies the GTE predicate on the "rate_limit_1d" field.
|
||||
func RateLimit1dGTE(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGTE(FieldRateLimit1d, v))
|
||||
}
|
||||
|
||||
// RateLimit1dLT applies the LT predicate on the "rate_limit_1d" field.
|
||||
func RateLimit1dLT(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLT(FieldRateLimit1d, v))
|
||||
}
|
||||
|
||||
// RateLimit1dLTE applies the LTE predicate on the "rate_limit_1d" field.
|
||||
func RateLimit1dLTE(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLTE(FieldRateLimit1d, v))
|
||||
}
|
||||
|
||||
// RateLimit7dEQ applies the EQ predicate on the "rate_limit_7d" field.
|
||||
func RateLimit7dEQ(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldRateLimit7d, v))
|
||||
}
|
||||
|
||||
// RateLimit7dNEQ applies the NEQ predicate on the "rate_limit_7d" field.
|
||||
func RateLimit7dNEQ(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNEQ(FieldRateLimit7d, v))
|
||||
}
|
||||
|
||||
// RateLimit7dIn applies the In predicate on the "rate_limit_7d" field.
|
||||
func RateLimit7dIn(vs ...float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIn(FieldRateLimit7d, vs...))
|
||||
}
|
||||
|
||||
// RateLimit7dNotIn applies the NotIn predicate on the "rate_limit_7d" field.
|
||||
func RateLimit7dNotIn(vs ...float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotIn(FieldRateLimit7d, vs...))
|
||||
}
|
||||
|
||||
// RateLimit7dGT applies the GT predicate on the "rate_limit_7d" field.
|
||||
func RateLimit7dGT(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGT(FieldRateLimit7d, v))
|
||||
}
|
||||
|
||||
// RateLimit7dGTE applies the GTE predicate on the "rate_limit_7d" field.
|
||||
func RateLimit7dGTE(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGTE(FieldRateLimit7d, v))
|
||||
}
|
||||
|
||||
// RateLimit7dLT applies the LT predicate on the "rate_limit_7d" field.
|
||||
func RateLimit7dLT(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLT(FieldRateLimit7d, v))
|
||||
}
|
||||
|
||||
// RateLimit7dLTE applies the LTE predicate on the "rate_limit_7d" field.
|
||||
func RateLimit7dLTE(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLTE(FieldRateLimit7d, v))
|
||||
}
|
||||
|
||||
// Usage5hEQ applies the EQ predicate on the "usage_5h" field.
|
||||
func Usage5hEQ(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldUsage5h, v))
|
||||
}
|
||||
|
||||
// Usage5hNEQ applies the NEQ predicate on the "usage_5h" field.
|
||||
func Usage5hNEQ(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNEQ(FieldUsage5h, v))
|
||||
}
|
||||
|
||||
// Usage5hIn applies the In predicate on the "usage_5h" field.
|
||||
func Usage5hIn(vs ...float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIn(FieldUsage5h, vs...))
|
||||
}
|
||||
|
||||
// Usage5hNotIn applies the NotIn predicate on the "usage_5h" field.
|
||||
func Usage5hNotIn(vs ...float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotIn(FieldUsage5h, vs...))
|
||||
}
|
||||
|
||||
// Usage5hGT applies the GT predicate on the "usage_5h" field.
|
||||
func Usage5hGT(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGT(FieldUsage5h, v))
|
||||
}
|
||||
|
||||
// Usage5hGTE applies the GTE predicate on the "usage_5h" field.
|
||||
func Usage5hGTE(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGTE(FieldUsage5h, v))
|
||||
}
|
||||
|
||||
// Usage5hLT applies the LT predicate on the "usage_5h" field.
|
||||
func Usage5hLT(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLT(FieldUsage5h, v))
|
||||
}
|
||||
|
||||
// Usage5hLTE applies the LTE predicate on the "usage_5h" field.
|
||||
func Usage5hLTE(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLTE(FieldUsage5h, v))
|
||||
}
|
||||
|
||||
// Usage1dEQ applies the EQ predicate on the "usage_1d" field.
|
||||
func Usage1dEQ(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldUsage1d, v))
|
||||
}
|
||||
|
||||
// Usage1dNEQ applies the NEQ predicate on the "usage_1d" field.
|
||||
func Usage1dNEQ(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNEQ(FieldUsage1d, v))
|
||||
}
|
||||
|
||||
// Usage1dIn applies the In predicate on the "usage_1d" field.
|
||||
func Usage1dIn(vs ...float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIn(FieldUsage1d, vs...))
|
||||
}
|
||||
|
||||
// Usage1dNotIn applies the NotIn predicate on the "usage_1d" field.
|
||||
func Usage1dNotIn(vs ...float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotIn(FieldUsage1d, vs...))
|
||||
}
|
||||
|
||||
// Usage1dGT applies the GT predicate on the "usage_1d" field.
|
||||
func Usage1dGT(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGT(FieldUsage1d, v))
|
||||
}
|
||||
|
||||
// Usage1dGTE applies the GTE predicate on the "usage_1d" field.
|
||||
func Usage1dGTE(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGTE(FieldUsage1d, v))
|
||||
}
|
||||
|
||||
// Usage1dLT applies the LT predicate on the "usage_1d" field.
|
||||
func Usage1dLT(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLT(FieldUsage1d, v))
|
||||
}
|
||||
|
||||
// Usage1dLTE applies the LTE predicate on the "usage_1d" field.
|
||||
func Usage1dLTE(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLTE(FieldUsage1d, v))
|
||||
}
|
||||
|
||||
// Usage7dEQ applies the EQ predicate on the "usage_7d" field.
|
||||
func Usage7dEQ(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldUsage7d, v))
|
||||
}
|
||||
|
||||
// Usage7dNEQ applies the NEQ predicate on the "usage_7d" field.
|
||||
func Usage7dNEQ(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNEQ(FieldUsage7d, v))
|
||||
}
|
||||
|
||||
// Usage7dIn applies the In predicate on the "usage_7d" field.
|
||||
func Usage7dIn(vs ...float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIn(FieldUsage7d, vs...))
|
||||
}
|
||||
|
||||
// Usage7dNotIn applies the NotIn predicate on the "usage_7d" field.
|
||||
func Usage7dNotIn(vs ...float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotIn(FieldUsage7d, vs...))
|
||||
}
|
||||
|
||||
// Usage7dGT applies the GT predicate on the "usage_7d" field.
|
||||
func Usage7dGT(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGT(FieldUsage7d, v))
|
||||
}
|
||||
|
||||
// Usage7dGTE applies the GTE predicate on the "usage_7d" field.
|
||||
func Usage7dGTE(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGTE(FieldUsage7d, v))
|
||||
}
|
||||
|
||||
// Usage7dLT applies the LT predicate on the "usage_7d" field.
|
||||
func Usage7dLT(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLT(FieldUsage7d, v))
|
||||
}
|
||||
|
||||
// Usage7dLTE applies the LTE predicate on the "usage_7d" field.
|
||||
func Usage7dLTE(v float64) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLTE(FieldUsage7d, v))
|
||||
}
|
||||
|
||||
// Window5hStartEQ applies the EQ predicate on the "window_5h_start" field.
|
||||
func Window5hStartEQ(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldWindow5hStart, v))
|
||||
}
|
||||
|
||||
// Window5hStartNEQ applies the NEQ predicate on the "window_5h_start" field.
|
||||
func Window5hStartNEQ(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNEQ(FieldWindow5hStart, v))
|
||||
}
|
||||
|
||||
// Window5hStartIn applies the In predicate on the "window_5h_start" field.
|
||||
func Window5hStartIn(vs ...time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIn(FieldWindow5hStart, vs...))
|
||||
}
|
||||
|
||||
// Window5hStartNotIn applies the NotIn predicate on the "window_5h_start" field.
|
||||
func Window5hStartNotIn(vs ...time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotIn(FieldWindow5hStart, vs...))
|
||||
}
|
||||
|
||||
// Window5hStartGT applies the GT predicate on the "window_5h_start" field.
|
||||
func Window5hStartGT(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGT(FieldWindow5hStart, v))
|
||||
}
|
||||
|
||||
// Window5hStartGTE applies the GTE predicate on the "window_5h_start" field.
|
||||
func Window5hStartGTE(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGTE(FieldWindow5hStart, v))
|
||||
}
|
||||
|
||||
// Window5hStartLT applies the LT predicate on the "window_5h_start" field.
|
||||
func Window5hStartLT(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLT(FieldWindow5hStart, v))
|
||||
}
|
||||
|
||||
// Window5hStartLTE applies the LTE predicate on the "window_5h_start" field.
|
||||
func Window5hStartLTE(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLTE(FieldWindow5hStart, v))
|
||||
}
|
||||
|
||||
// Window5hStartIsNil applies the IsNil predicate on the "window_5h_start" field.
|
||||
func Window5hStartIsNil() predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIsNull(FieldWindow5hStart))
|
||||
}
|
||||
|
||||
// Window5hStartNotNil applies the NotNil predicate on the "window_5h_start" field.
|
||||
func Window5hStartNotNil() predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotNull(FieldWindow5hStart))
|
||||
}
|
||||
|
||||
// Window1dStartEQ applies the EQ predicate on the "window_1d_start" field.
|
||||
func Window1dStartEQ(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldWindow1dStart, v))
|
||||
}
|
||||
|
||||
// Window1dStartNEQ applies the NEQ predicate on the "window_1d_start" field.
|
||||
func Window1dStartNEQ(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNEQ(FieldWindow1dStart, v))
|
||||
}
|
||||
|
||||
// Window1dStartIn applies the In predicate on the "window_1d_start" field.
|
||||
func Window1dStartIn(vs ...time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIn(FieldWindow1dStart, vs...))
|
||||
}
|
||||
|
||||
// Window1dStartNotIn applies the NotIn predicate on the "window_1d_start" field.
|
||||
func Window1dStartNotIn(vs ...time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotIn(FieldWindow1dStart, vs...))
|
||||
}
|
||||
|
||||
// Window1dStartGT applies the GT predicate on the "window_1d_start" field.
|
||||
func Window1dStartGT(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGT(FieldWindow1dStart, v))
|
||||
}
|
||||
|
||||
// Window1dStartGTE applies the GTE predicate on the "window_1d_start" field.
|
||||
func Window1dStartGTE(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGTE(FieldWindow1dStart, v))
|
||||
}
|
||||
|
||||
// Window1dStartLT applies the LT predicate on the "window_1d_start" field.
|
||||
func Window1dStartLT(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLT(FieldWindow1dStart, v))
|
||||
}
|
||||
|
||||
// Window1dStartLTE applies the LTE predicate on the "window_1d_start" field.
|
||||
func Window1dStartLTE(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLTE(FieldWindow1dStart, v))
|
||||
}
|
||||
|
||||
// Window1dStartIsNil applies the IsNil predicate on the "window_1d_start" field.
|
||||
func Window1dStartIsNil() predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIsNull(FieldWindow1dStart))
|
||||
}
|
||||
|
||||
// Window1dStartNotNil applies the NotNil predicate on the "window_1d_start" field.
|
||||
func Window1dStartNotNil() predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotNull(FieldWindow1dStart))
|
||||
}
|
||||
|
||||
// Window7dStartEQ applies the EQ predicate on the "window_7d_start" field.
|
||||
func Window7dStartEQ(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldEQ(FieldWindow7dStart, v))
|
||||
}
|
||||
|
||||
// Window7dStartNEQ applies the NEQ predicate on the "window_7d_start" field.
|
||||
func Window7dStartNEQ(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNEQ(FieldWindow7dStart, v))
|
||||
}
|
||||
|
||||
// Window7dStartIn applies the In predicate on the "window_7d_start" field.
|
||||
func Window7dStartIn(vs ...time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIn(FieldWindow7dStart, vs...))
|
||||
}
|
||||
|
||||
// Window7dStartNotIn applies the NotIn predicate on the "window_7d_start" field.
|
||||
func Window7dStartNotIn(vs ...time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotIn(FieldWindow7dStart, vs...))
|
||||
}
|
||||
|
||||
// Window7dStartGT applies the GT predicate on the "window_7d_start" field.
|
||||
func Window7dStartGT(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGT(FieldWindow7dStart, v))
|
||||
}
|
||||
|
||||
// Window7dStartGTE applies the GTE predicate on the "window_7d_start" field.
|
||||
func Window7dStartGTE(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldGTE(FieldWindow7dStart, v))
|
||||
}
|
||||
|
||||
// Window7dStartLT applies the LT predicate on the "window_7d_start" field.
|
||||
func Window7dStartLT(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLT(FieldWindow7dStart, v))
|
||||
}
|
||||
|
||||
// Window7dStartLTE applies the LTE predicate on the "window_7d_start" field.
|
||||
func Window7dStartLTE(v time.Time) predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldLTE(FieldWindow7dStart, v))
|
||||
}
|
||||
|
||||
// Window7dStartIsNil applies the IsNil predicate on the "window_7d_start" field.
|
||||
func Window7dStartIsNil() predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldIsNull(FieldWindow7dStart))
|
||||
}
|
||||
|
||||
// Window7dStartNotNil applies the NotNil predicate on the "window_7d_start" field.
|
||||
func Window7dStartNotNil() predicate.APIKey {
|
||||
return predicate.APIKey(sql.FieldNotNull(FieldWindow7dStart))
|
||||
}
|
||||
|
||||
// HasUser applies the HasEdge predicate on the "user" edge.
|
||||
func HasUser() predicate.APIKey {
|
||||
return predicate.APIKey(func(s *sql.Selector) {
|
||||
|
||||
@@ -181,6 +181,132 @@ func (_c *APIKeyCreate) SetNillableExpiresAt(v *time.Time) *APIKeyCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetRateLimit5h sets the "rate_limit_5h" field.
|
||||
func (_c *APIKeyCreate) SetRateLimit5h(v float64) *APIKeyCreate {
|
||||
_c.mutation.SetRateLimit5h(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil.
|
||||
func (_c *APIKeyCreate) SetNillableRateLimit5h(v *float64) *APIKeyCreate {
|
||||
if v != nil {
|
||||
_c.SetRateLimit5h(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetRateLimit1d sets the "rate_limit_1d" field.
|
||||
func (_c *APIKeyCreate) SetRateLimit1d(v float64) *APIKeyCreate {
|
||||
_c.mutation.SetRateLimit1d(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil.
|
||||
func (_c *APIKeyCreate) SetNillableRateLimit1d(v *float64) *APIKeyCreate {
|
||||
if v != nil {
|
||||
_c.SetRateLimit1d(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetRateLimit7d sets the "rate_limit_7d" field.
|
||||
func (_c *APIKeyCreate) SetRateLimit7d(v float64) *APIKeyCreate {
|
||||
_c.mutation.SetRateLimit7d(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil.
|
||||
func (_c *APIKeyCreate) SetNillableRateLimit7d(v *float64) *APIKeyCreate {
|
||||
if v != nil {
|
||||
_c.SetRateLimit7d(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetUsage5h sets the "usage_5h" field.
|
||||
func (_c *APIKeyCreate) SetUsage5h(v float64) *APIKeyCreate {
|
||||
_c.mutation.SetUsage5h(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil.
|
||||
func (_c *APIKeyCreate) SetNillableUsage5h(v *float64) *APIKeyCreate {
|
||||
if v != nil {
|
||||
_c.SetUsage5h(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetUsage1d sets the "usage_1d" field.
|
||||
func (_c *APIKeyCreate) SetUsage1d(v float64) *APIKeyCreate {
|
||||
_c.mutation.SetUsage1d(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil.
|
||||
func (_c *APIKeyCreate) SetNillableUsage1d(v *float64) *APIKeyCreate {
|
||||
if v != nil {
|
||||
_c.SetUsage1d(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetUsage7d sets the "usage_7d" field.
|
||||
func (_c *APIKeyCreate) SetUsage7d(v float64) *APIKeyCreate {
|
||||
_c.mutation.SetUsage7d(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil.
|
||||
func (_c *APIKeyCreate) SetNillableUsage7d(v *float64) *APIKeyCreate {
|
||||
if v != nil {
|
||||
_c.SetUsage7d(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetWindow5hStart sets the "window_5h_start" field.
|
||||
func (_c *APIKeyCreate) SetWindow5hStart(v time.Time) *APIKeyCreate {
|
||||
_c.mutation.SetWindow5hStart(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil.
|
||||
func (_c *APIKeyCreate) SetNillableWindow5hStart(v *time.Time) *APIKeyCreate {
|
||||
if v != nil {
|
||||
_c.SetWindow5hStart(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetWindow1dStart sets the "window_1d_start" field.
|
||||
func (_c *APIKeyCreate) SetWindow1dStart(v time.Time) *APIKeyCreate {
|
||||
_c.mutation.SetWindow1dStart(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil.
|
||||
func (_c *APIKeyCreate) SetNillableWindow1dStart(v *time.Time) *APIKeyCreate {
|
||||
if v != nil {
|
||||
_c.SetWindow1dStart(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetWindow7dStart sets the "window_7d_start" field.
|
||||
func (_c *APIKeyCreate) SetWindow7dStart(v time.Time) *APIKeyCreate {
|
||||
_c.mutation.SetWindow7dStart(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil.
|
||||
func (_c *APIKeyCreate) SetNillableWindow7dStart(v *time.Time) *APIKeyCreate {
|
||||
if v != nil {
|
||||
_c.SetWindow7dStart(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate {
|
||||
return _c.SetUserID(v.ID)
|
||||
@@ -269,6 +395,30 @@ func (_c *APIKeyCreate) defaults() error {
|
||||
v := apikey.DefaultQuotaUsed
|
||||
_c.mutation.SetQuotaUsed(v)
|
||||
}
|
||||
if _, ok := _c.mutation.RateLimit5h(); !ok {
|
||||
v := apikey.DefaultRateLimit5h
|
||||
_c.mutation.SetRateLimit5h(v)
|
||||
}
|
||||
if _, ok := _c.mutation.RateLimit1d(); !ok {
|
||||
v := apikey.DefaultRateLimit1d
|
||||
_c.mutation.SetRateLimit1d(v)
|
||||
}
|
||||
if _, ok := _c.mutation.RateLimit7d(); !ok {
|
||||
v := apikey.DefaultRateLimit7d
|
||||
_c.mutation.SetRateLimit7d(v)
|
||||
}
|
||||
if _, ok := _c.mutation.Usage5h(); !ok {
|
||||
v := apikey.DefaultUsage5h
|
||||
_c.mutation.SetUsage5h(v)
|
||||
}
|
||||
if _, ok := _c.mutation.Usage1d(); !ok {
|
||||
v := apikey.DefaultUsage1d
|
||||
_c.mutation.SetUsage1d(v)
|
||||
}
|
||||
if _, ok := _c.mutation.Usage7d(); !ok {
|
||||
v := apikey.DefaultUsage7d
|
||||
_c.mutation.SetUsage7d(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -313,6 +463,24 @@ func (_c *APIKeyCreate) check() error {
|
||||
if _, ok := _c.mutation.QuotaUsed(); !ok {
|
||||
return &ValidationError{Name: "quota_used", err: errors.New(`ent: missing required field "APIKey.quota_used"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.RateLimit5h(); !ok {
|
||||
return &ValidationError{Name: "rate_limit_5h", err: errors.New(`ent: missing required field "APIKey.rate_limit_5h"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.RateLimit1d(); !ok {
|
||||
return &ValidationError{Name: "rate_limit_1d", err: errors.New(`ent: missing required field "APIKey.rate_limit_1d"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.RateLimit7d(); !ok {
|
||||
return &ValidationError{Name: "rate_limit_7d", err: errors.New(`ent: missing required field "APIKey.rate_limit_7d"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.Usage5h(); !ok {
|
||||
return &ValidationError{Name: "usage_5h", err: errors.New(`ent: missing required field "APIKey.usage_5h"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.Usage1d(); !ok {
|
||||
return &ValidationError{Name: "usage_1d", err: errors.New(`ent: missing required field "APIKey.usage_1d"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.Usage7d(); !ok {
|
||||
return &ValidationError{Name: "usage_7d", err: errors.New(`ent: missing required field "APIKey.usage_7d"`)}
|
||||
}
|
||||
if len(_c.mutation.UserIDs()) == 0 {
|
||||
return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "APIKey.user"`)}
|
||||
}
|
||||
@@ -391,6 +559,42 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value)
|
||||
_node.ExpiresAt = &value
|
||||
}
|
||||
if value, ok := _c.mutation.RateLimit5h(); ok {
|
||||
_spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value)
|
||||
_node.RateLimit5h = value
|
||||
}
|
||||
if value, ok := _c.mutation.RateLimit1d(); ok {
|
||||
_spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value)
|
||||
_node.RateLimit1d = value
|
||||
}
|
||||
if value, ok := _c.mutation.RateLimit7d(); ok {
|
||||
_spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value)
|
||||
_node.RateLimit7d = value
|
||||
}
|
||||
if value, ok := _c.mutation.Usage5h(); ok {
|
||||
_spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value)
|
||||
_node.Usage5h = value
|
||||
}
|
||||
if value, ok := _c.mutation.Usage1d(); ok {
|
||||
_spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value)
|
||||
_node.Usage1d = value
|
||||
}
|
||||
if value, ok := _c.mutation.Usage7d(); ok {
|
||||
_spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, value)
|
||||
_node.Usage7d = value
|
||||
}
|
||||
if value, ok := _c.mutation.Window5hStart(); ok {
|
||||
_spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value)
|
||||
_node.Window5hStart = &value
|
||||
}
|
||||
if value, ok := _c.mutation.Window1dStart(); ok {
|
||||
_spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value)
|
||||
_node.Window1dStart = &value
|
||||
}
|
||||
if value, ok := _c.mutation.Window7dStart(); ok {
|
||||
_spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value)
|
||||
_node.Window7dStart = &value
|
||||
}
|
||||
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
@@ -697,6 +901,168 @@ func (u *APIKeyUpsert) ClearExpiresAt() *APIKeyUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetRateLimit5h sets the "rate_limit_5h" field.
|
||||
func (u *APIKeyUpsert) SetRateLimit5h(v float64) *APIKeyUpsert {
|
||||
u.Set(apikey.FieldRateLimit5h, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsert) UpdateRateLimit5h() *APIKeyUpsert {
|
||||
u.SetExcluded(apikey.FieldRateLimit5h)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddRateLimit5h adds v to the "rate_limit_5h" field.
|
||||
func (u *APIKeyUpsert) AddRateLimit5h(v float64) *APIKeyUpsert {
|
||||
u.Add(apikey.FieldRateLimit5h, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetRateLimit1d sets the "rate_limit_1d" field.
|
||||
func (u *APIKeyUpsert) SetRateLimit1d(v float64) *APIKeyUpsert {
|
||||
u.Set(apikey.FieldRateLimit1d, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsert) UpdateRateLimit1d() *APIKeyUpsert {
|
||||
u.SetExcluded(apikey.FieldRateLimit1d)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddRateLimit1d adds v to the "rate_limit_1d" field.
|
||||
func (u *APIKeyUpsert) AddRateLimit1d(v float64) *APIKeyUpsert {
|
||||
u.Add(apikey.FieldRateLimit1d, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetRateLimit7d sets the "rate_limit_7d" field.
|
||||
func (u *APIKeyUpsert) SetRateLimit7d(v float64) *APIKeyUpsert {
|
||||
u.Set(apikey.FieldRateLimit7d, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsert) UpdateRateLimit7d() *APIKeyUpsert {
|
||||
u.SetExcluded(apikey.FieldRateLimit7d)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddRateLimit7d adds v to the "rate_limit_7d" field.
|
||||
func (u *APIKeyUpsert) AddRateLimit7d(v float64) *APIKeyUpsert {
|
||||
u.Add(apikey.FieldRateLimit7d, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetUsage5h sets the "usage_5h" field.
|
||||
func (u *APIKeyUpsert) SetUsage5h(v float64) *APIKeyUpsert {
|
||||
u.Set(apikey.FieldUsage5h, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsert) UpdateUsage5h() *APIKeyUpsert {
|
||||
u.SetExcluded(apikey.FieldUsage5h)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddUsage5h adds v to the "usage_5h" field.
|
||||
func (u *APIKeyUpsert) AddUsage5h(v float64) *APIKeyUpsert {
|
||||
u.Add(apikey.FieldUsage5h, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetUsage1d sets the "usage_1d" field.
|
||||
func (u *APIKeyUpsert) SetUsage1d(v float64) *APIKeyUpsert {
|
||||
u.Set(apikey.FieldUsage1d, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsert) UpdateUsage1d() *APIKeyUpsert {
|
||||
u.SetExcluded(apikey.FieldUsage1d)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddUsage1d adds v to the "usage_1d" field.
|
||||
func (u *APIKeyUpsert) AddUsage1d(v float64) *APIKeyUpsert {
|
||||
u.Add(apikey.FieldUsage1d, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetUsage7d sets the "usage_7d" field.
|
||||
func (u *APIKeyUpsert) SetUsage7d(v float64) *APIKeyUpsert {
|
||||
u.Set(apikey.FieldUsage7d, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsert) UpdateUsage7d() *APIKeyUpsert {
|
||||
u.SetExcluded(apikey.FieldUsage7d)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddUsage7d adds v to the "usage_7d" field.
|
||||
func (u *APIKeyUpsert) AddUsage7d(v float64) *APIKeyUpsert {
|
||||
u.Add(apikey.FieldUsage7d, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetWindow5hStart sets the "window_5h_start" field.
|
||||
func (u *APIKeyUpsert) SetWindow5hStart(v time.Time) *APIKeyUpsert {
|
||||
u.Set(apikey.FieldWindow5hStart, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsert) UpdateWindow5hStart() *APIKeyUpsert {
|
||||
u.SetExcluded(apikey.FieldWindow5hStart)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearWindow5hStart clears the value of the "window_5h_start" field.
|
||||
func (u *APIKeyUpsert) ClearWindow5hStart() *APIKeyUpsert {
|
||||
u.SetNull(apikey.FieldWindow5hStart)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetWindow1dStart sets the "window_1d_start" field.
|
||||
func (u *APIKeyUpsert) SetWindow1dStart(v time.Time) *APIKeyUpsert {
|
||||
u.Set(apikey.FieldWindow1dStart, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsert) UpdateWindow1dStart() *APIKeyUpsert {
|
||||
u.SetExcluded(apikey.FieldWindow1dStart)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearWindow1dStart clears the value of the "window_1d_start" field.
|
||||
func (u *APIKeyUpsert) ClearWindow1dStart() *APIKeyUpsert {
|
||||
u.SetNull(apikey.FieldWindow1dStart)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetWindow7dStart sets the "window_7d_start" field.
|
||||
func (u *APIKeyUpsert) SetWindow7dStart(v time.Time) *APIKeyUpsert {
|
||||
u.Set(apikey.FieldWindow7dStart, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsert) UpdateWindow7dStart() *APIKeyUpsert {
|
||||
u.SetExcluded(apikey.FieldWindow7dStart)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearWindow7dStart clears the value of the "window_7d_start" field.
|
||||
func (u *APIKeyUpsert) ClearWindow7dStart() *APIKeyUpsert {
|
||||
u.SetNull(apikey.FieldWindow7dStart)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -980,6 +1346,195 @@ func (u *APIKeyUpsertOne) ClearExpiresAt() *APIKeyUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRateLimit5h sets the "rate_limit_5h" field.
|
||||
func (u *APIKeyUpsertOne) SetRateLimit5h(v float64) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetRateLimit5h(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddRateLimit5h adds v to the "rate_limit_5h" field.
|
||||
func (u *APIKeyUpsertOne) AddRateLimit5h(v float64) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.AddRateLimit5h(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertOne) UpdateRateLimit5h() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateRateLimit5h()
|
||||
})
|
||||
}
|
||||
|
||||
// SetRateLimit1d sets the "rate_limit_1d" field.
|
||||
func (u *APIKeyUpsertOne) SetRateLimit1d(v float64) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetRateLimit1d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddRateLimit1d adds v to the "rate_limit_1d" field.
|
||||
func (u *APIKeyUpsertOne) AddRateLimit1d(v float64) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.AddRateLimit1d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertOne) UpdateRateLimit1d() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateRateLimit1d()
|
||||
})
|
||||
}
|
||||
|
||||
// SetRateLimit7d sets the "rate_limit_7d" field.
|
||||
func (u *APIKeyUpsertOne) SetRateLimit7d(v float64) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetRateLimit7d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddRateLimit7d adds v to the "rate_limit_7d" field.
|
||||
func (u *APIKeyUpsertOne) AddRateLimit7d(v float64) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.AddRateLimit7d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertOne) UpdateRateLimit7d() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateRateLimit7d()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUsage5h sets the "usage_5h" field.
|
||||
func (u *APIKeyUpsertOne) SetUsage5h(v float64) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetUsage5h(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddUsage5h adds v to the "usage_5h" field.
|
||||
func (u *APIKeyUpsertOne) AddUsage5h(v float64) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.AddUsage5h(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertOne) UpdateUsage5h() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateUsage5h()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUsage1d sets the "usage_1d" field.
|
||||
func (u *APIKeyUpsertOne) SetUsage1d(v float64) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetUsage1d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddUsage1d adds v to the "usage_1d" field.
|
||||
func (u *APIKeyUpsertOne) AddUsage1d(v float64) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.AddUsage1d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertOne) UpdateUsage1d() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateUsage1d()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUsage7d sets the "usage_7d" field.
|
||||
func (u *APIKeyUpsertOne) SetUsage7d(v float64) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetUsage7d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddUsage7d adds v to the "usage_7d" field.
|
||||
func (u *APIKeyUpsertOne) AddUsage7d(v float64) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.AddUsage7d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertOne) UpdateUsage7d() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateUsage7d()
|
||||
})
|
||||
}
|
||||
|
||||
// SetWindow5hStart sets the "window_5h_start" field.
|
||||
func (u *APIKeyUpsertOne) SetWindow5hStart(v time.Time) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetWindow5hStart(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertOne) UpdateWindow5hStart() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateWindow5hStart()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearWindow5hStart clears the value of the "window_5h_start" field.
|
||||
func (u *APIKeyUpsertOne) ClearWindow5hStart() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.ClearWindow5hStart()
|
||||
})
|
||||
}
|
||||
|
||||
// SetWindow1dStart sets the "window_1d_start" field.
|
||||
func (u *APIKeyUpsertOne) SetWindow1dStart(v time.Time) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetWindow1dStart(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertOne) UpdateWindow1dStart() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateWindow1dStart()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearWindow1dStart clears the value of the "window_1d_start" field.
|
||||
func (u *APIKeyUpsertOne) ClearWindow1dStart() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.ClearWindow1dStart()
|
||||
})
|
||||
}
|
||||
|
||||
// SetWindow7dStart sets the "window_7d_start" field.
|
||||
func (u *APIKeyUpsertOne) SetWindow7dStart(v time.Time) *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetWindow7dStart(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertOne) UpdateWindow7dStart() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateWindow7dStart()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearWindow7dStart clears the value of the "window_7d_start" field.
|
||||
func (u *APIKeyUpsertOne) ClearWindow7dStart() *APIKeyUpsertOne {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.ClearWindow7dStart()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *APIKeyUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -1429,6 +1984,195 @@ func (u *APIKeyUpsertBulk) ClearExpiresAt() *APIKeyUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetRateLimit5h sets the "rate_limit_5h" field.
|
||||
func (u *APIKeyUpsertBulk) SetRateLimit5h(v float64) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetRateLimit5h(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddRateLimit5h adds v to the "rate_limit_5h" field.
|
||||
func (u *APIKeyUpsertBulk) AddRateLimit5h(v float64) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.AddRateLimit5h(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertBulk) UpdateRateLimit5h() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateRateLimit5h()
|
||||
})
|
||||
}
|
||||
|
||||
// SetRateLimit1d sets the "rate_limit_1d" field.
|
||||
func (u *APIKeyUpsertBulk) SetRateLimit1d(v float64) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetRateLimit1d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddRateLimit1d adds v to the "rate_limit_1d" field.
|
||||
func (u *APIKeyUpsertBulk) AddRateLimit1d(v float64) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.AddRateLimit1d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertBulk) UpdateRateLimit1d() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateRateLimit1d()
|
||||
})
|
||||
}
|
||||
|
||||
// SetRateLimit7d sets the "rate_limit_7d" field.
|
||||
func (u *APIKeyUpsertBulk) SetRateLimit7d(v float64) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetRateLimit7d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddRateLimit7d adds v to the "rate_limit_7d" field.
|
||||
func (u *APIKeyUpsertBulk) AddRateLimit7d(v float64) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.AddRateLimit7d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertBulk) UpdateRateLimit7d() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateRateLimit7d()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUsage5h sets the "usage_5h" field.
|
||||
func (u *APIKeyUpsertBulk) SetUsage5h(v float64) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetUsage5h(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddUsage5h adds v to the "usage_5h" field.
|
||||
func (u *APIKeyUpsertBulk) AddUsage5h(v float64) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.AddUsage5h(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertBulk) UpdateUsage5h() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateUsage5h()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUsage1d sets the "usage_1d" field.
|
||||
func (u *APIKeyUpsertBulk) SetUsage1d(v float64) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetUsage1d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddUsage1d adds v to the "usage_1d" field.
|
||||
func (u *APIKeyUpsertBulk) AddUsage1d(v float64) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.AddUsage1d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertBulk) UpdateUsage1d() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateUsage1d()
|
||||
})
|
||||
}
|
||||
|
||||
// SetUsage7d sets the "usage_7d" field.
|
||||
func (u *APIKeyUpsertBulk) SetUsage7d(v float64) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetUsage7d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddUsage7d adds v to the "usage_7d" field.
|
||||
func (u *APIKeyUpsertBulk) AddUsage7d(v float64) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.AddUsage7d(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertBulk) UpdateUsage7d() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateUsage7d()
|
||||
})
|
||||
}
|
||||
|
||||
// SetWindow5hStart sets the "window_5h_start" field.
|
||||
func (u *APIKeyUpsertBulk) SetWindow5hStart(v time.Time) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetWindow5hStart(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertBulk) UpdateWindow5hStart() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateWindow5hStart()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearWindow5hStart clears the value of the "window_5h_start" field.
|
||||
func (u *APIKeyUpsertBulk) ClearWindow5hStart() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.ClearWindow5hStart()
|
||||
})
|
||||
}
|
||||
|
||||
// SetWindow1dStart sets the "window_1d_start" field.
|
||||
func (u *APIKeyUpsertBulk) SetWindow1dStart(v time.Time) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetWindow1dStart(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertBulk) UpdateWindow1dStart() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateWindow1dStart()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearWindow1dStart clears the value of the "window_1d_start" field.
|
||||
func (u *APIKeyUpsertBulk) ClearWindow1dStart() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.ClearWindow1dStart()
|
||||
})
|
||||
}
|
||||
|
||||
// SetWindow7dStart sets the "window_7d_start" field.
|
||||
func (u *APIKeyUpsertBulk) SetWindow7dStart(v time.Time) *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.SetWindow7dStart(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create.
|
||||
func (u *APIKeyUpsertBulk) UpdateWindow7dStart() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.UpdateWindow7dStart()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearWindow7dStart clears the value of the "window_7d_start" field.
|
||||
func (u *APIKeyUpsertBulk) ClearWindow7dStart() *APIKeyUpsertBulk {
|
||||
return u.Update(func(s *APIKeyUpsert) {
|
||||
s.ClearWindow7dStart()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -252,6 +252,192 @@ func (_u *APIKeyUpdate) ClearExpiresAt() *APIKeyUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRateLimit5h sets the "rate_limit_5h" field.
|
||||
func (_u *APIKeyUpdate) SetRateLimit5h(v float64) *APIKeyUpdate {
|
||||
_u.mutation.ResetRateLimit5h()
|
||||
_u.mutation.SetRateLimit5h(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdate) SetNillableRateLimit5h(v *float64) *APIKeyUpdate {
|
||||
if v != nil {
|
||||
_u.SetRateLimit5h(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddRateLimit5h adds value to the "rate_limit_5h" field.
|
||||
func (_u *APIKeyUpdate) AddRateLimit5h(v float64) *APIKeyUpdate {
|
||||
_u.mutation.AddRateLimit5h(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRateLimit1d sets the "rate_limit_1d" field.
|
||||
func (_u *APIKeyUpdate) SetRateLimit1d(v float64) *APIKeyUpdate {
|
||||
_u.mutation.ResetRateLimit1d()
|
||||
_u.mutation.SetRateLimit1d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdate) SetNillableRateLimit1d(v *float64) *APIKeyUpdate {
|
||||
if v != nil {
|
||||
_u.SetRateLimit1d(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddRateLimit1d adds value to the "rate_limit_1d" field.
|
||||
func (_u *APIKeyUpdate) AddRateLimit1d(v float64) *APIKeyUpdate {
|
||||
_u.mutation.AddRateLimit1d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRateLimit7d sets the "rate_limit_7d" field.
|
||||
func (_u *APIKeyUpdate) SetRateLimit7d(v float64) *APIKeyUpdate {
|
||||
_u.mutation.ResetRateLimit7d()
|
||||
_u.mutation.SetRateLimit7d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdate) SetNillableRateLimit7d(v *float64) *APIKeyUpdate {
|
||||
if v != nil {
|
||||
_u.SetRateLimit7d(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddRateLimit7d adds value to the "rate_limit_7d" field.
|
||||
func (_u *APIKeyUpdate) AddRateLimit7d(v float64) *APIKeyUpdate {
|
||||
_u.mutation.AddRateLimit7d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUsage5h sets the "usage_5h" field.
|
||||
func (_u *APIKeyUpdate) SetUsage5h(v float64) *APIKeyUpdate {
|
||||
_u.mutation.ResetUsage5h()
|
||||
_u.mutation.SetUsage5h(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdate) SetNillableUsage5h(v *float64) *APIKeyUpdate {
|
||||
if v != nil {
|
||||
_u.SetUsage5h(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddUsage5h adds value to the "usage_5h" field.
|
||||
func (_u *APIKeyUpdate) AddUsage5h(v float64) *APIKeyUpdate {
|
||||
_u.mutation.AddUsage5h(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUsage1d sets the "usage_1d" field.
|
||||
func (_u *APIKeyUpdate) SetUsage1d(v float64) *APIKeyUpdate {
|
||||
_u.mutation.ResetUsage1d()
|
||||
_u.mutation.SetUsage1d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdate) SetNillableUsage1d(v *float64) *APIKeyUpdate {
|
||||
if v != nil {
|
||||
_u.SetUsage1d(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddUsage1d adds value to the "usage_1d" field.
|
||||
func (_u *APIKeyUpdate) AddUsage1d(v float64) *APIKeyUpdate {
|
||||
_u.mutation.AddUsage1d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUsage7d sets the "usage_7d" field.
|
||||
func (_u *APIKeyUpdate) SetUsage7d(v float64) *APIKeyUpdate {
|
||||
_u.mutation.ResetUsage7d()
|
||||
_u.mutation.SetUsage7d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdate) SetNillableUsage7d(v *float64) *APIKeyUpdate {
|
||||
if v != nil {
|
||||
_u.SetUsage7d(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddUsage7d adds value to the "usage_7d" field.
|
||||
func (_u *APIKeyUpdate) AddUsage7d(v float64) *APIKeyUpdate {
|
||||
_u.mutation.AddUsage7d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWindow5hStart sets the "window_5h_start" field.
|
||||
func (_u *APIKeyUpdate) SetWindow5hStart(v time.Time) *APIKeyUpdate {
|
||||
_u.mutation.SetWindow5hStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdate) SetNillableWindow5hStart(v *time.Time) *APIKeyUpdate {
|
||||
if v != nil {
|
||||
_u.SetWindow5hStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWindow5hStart clears the value of the "window_5h_start" field.
|
||||
func (_u *APIKeyUpdate) ClearWindow5hStart() *APIKeyUpdate {
|
||||
_u.mutation.ClearWindow5hStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWindow1dStart sets the "window_1d_start" field.
|
||||
func (_u *APIKeyUpdate) SetWindow1dStart(v time.Time) *APIKeyUpdate {
|
||||
_u.mutation.SetWindow1dStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdate) SetNillableWindow1dStart(v *time.Time) *APIKeyUpdate {
|
||||
if v != nil {
|
||||
_u.SetWindow1dStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWindow1dStart clears the value of the "window_1d_start" field.
|
||||
func (_u *APIKeyUpdate) ClearWindow1dStart() *APIKeyUpdate {
|
||||
_u.mutation.ClearWindow1dStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWindow7dStart sets the "window_7d_start" field.
|
||||
func (_u *APIKeyUpdate) SetWindow7dStart(v time.Time) *APIKeyUpdate {
|
||||
_u.mutation.SetWindow7dStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdate) SetNillableWindow7dStart(v *time.Time) *APIKeyUpdate {
|
||||
if v != nil {
|
||||
_u.SetWindow7dStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWindow7dStart clears the value of the "window_7d_start" field.
|
||||
func (_u *APIKeyUpdate) ClearWindow7dStart() *APIKeyUpdate {
|
||||
_u.mutation.ClearWindow7dStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate {
|
||||
return _u.SetUserID(v.ID)
|
||||
@@ -456,6 +642,60 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if _u.mutation.ExpiresAtCleared() {
|
||||
_spec.ClearField(apikey.FieldExpiresAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.RateLimit5h(); ok {
|
||||
_spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedRateLimit5h(); ok {
|
||||
_spec.AddField(apikey.FieldRateLimit5h, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RateLimit1d(); ok {
|
||||
_spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedRateLimit1d(); ok {
|
||||
_spec.AddField(apikey.FieldRateLimit1d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RateLimit7d(); ok {
|
||||
_spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedRateLimit7d(); ok {
|
||||
_spec.AddField(apikey.FieldRateLimit7d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Usage5h(); ok {
|
||||
_spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedUsage5h(); ok {
|
||||
_spec.AddField(apikey.FieldUsage5h, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Usage1d(); ok {
|
||||
_spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedUsage1d(); ok {
|
||||
_spec.AddField(apikey.FieldUsage1d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Usage7d(); ok {
|
||||
_spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedUsage7d(); ok {
|
||||
_spec.AddField(apikey.FieldUsage7d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Window5hStart(); ok {
|
||||
_spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.Window5hStartCleared() {
|
||||
_spec.ClearField(apikey.FieldWindow5hStart, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.Window1dStart(); ok {
|
||||
_spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.Window1dStartCleared() {
|
||||
_spec.ClearField(apikey.FieldWindow1dStart, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.Window7dStart(); ok {
|
||||
_spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.Window7dStartCleared() {
|
||||
_spec.ClearField(apikey.FieldWindow7dStart, field.TypeTime)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
@@ -799,6 +1039,192 @@ func (_u *APIKeyUpdateOne) ClearExpiresAt() *APIKeyUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRateLimit5h sets the "rate_limit_5h" field.
|
||||
func (_u *APIKeyUpdateOne) SetRateLimit5h(v float64) *APIKeyUpdateOne {
|
||||
_u.mutation.ResetRateLimit5h()
|
||||
_u.mutation.SetRateLimit5h(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdateOne) SetNillableRateLimit5h(v *float64) *APIKeyUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetRateLimit5h(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddRateLimit5h adds value to the "rate_limit_5h" field.
|
||||
func (_u *APIKeyUpdateOne) AddRateLimit5h(v float64) *APIKeyUpdateOne {
|
||||
_u.mutation.AddRateLimit5h(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRateLimit1d sets the "rate_limit_1d" field.
|
||||
func (_u *APIKeyUpdateOne) SetRateLimit1d(v float64) *APIKeyUpdateOne {
|
||||
_u.mutation.ResetRateLimit1d()
|
||||
_u.mutation.SetRateLimit1d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdateOne) SetNillableRateLimit1d(v *float64) *APIKeyUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetRateLimit1d(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddRateLimit1d adds value to the "rate_limit_1d" field.
|
||||
func (_u *APIKeyUpdateOne) AddRateLimit1d(v float64) *APIKeyUpdateOne {
|
||||
_u.mutation.AddRateLimit1d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetRateLimit7d sets the "rate_limit_7d" field.
|
||||
func (_u *APIKeyUpdateOne) SetRateLimit7d(v float64) *APIKeyUpdateOne {
|
||||
_u.mutation.ResetRateLimit7d()
|
||||
_u.mutation.SetRateLimit7d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdateOne) SetNillableRateLimit7d(v *float64) *APIKeyUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetRateLimit7d(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddRateLimit7d adds value to the "rate_limit_7d" field.
|
||||
func (_u *APIKeyUpdateOne) AddRateLimit7d(v float64) *APIKeyUpdateOne {
|
||||
_u.mutation.AddRateLimit7d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUsage5h sets the "usage_5h" field.
|
||||
func (_u *APIKeyUpdateOne) SetUsage5h(v float64) *APIKeyUpdateOne {
|
||||
_u.mutation.ResetUsage5h()
|
||||
_u.mutation.SetUsage5h(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdateOne) SetNillableUsage5h(v *float64) *APIKeyUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetUsage5h(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddUsage5h adds value to the "usage_5h" field.
|
||||
func (_u *APIKeyUpdateOne) AddUsage5h(v float64) *APIKeyUpdateOne {
|
||||
_u.mutation.AddUsage5h(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUsage1d sets the "usage_1d" field.
|
||||
func (_u *APIKeyUpdateOne) SetUsage1d(v float64) *APIKeyUpdateOne {
|
||||
_u.mutation.ResetUsage1d()
|
||||
_u.mutation.SetUsage1d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdateOne) SetNillableUsage1d(v *float64) *APIKeyUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetUsage1d(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddUsage1d adds value to the "usage_1d" field.
|
||||
func (_u *APIKeyUpdateOne) AddUsage1d(v float64) *APIKeyUpdateOne {
|
||||
_u.mutation.AddUsage1d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUsage7d sets the "usage_7d" field.
|
||||
func (_u *APIKeyUpdateOne) SetUsage7d(v float64) *APIKeyUpdateOne {
|
||||
_u.mutation.ResetUsage7d()
|
||||
_u.mutation.SetUsage7d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdateOne) SetNillableUsage7d(v *float64) *APIKeyUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetUsage7d(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddUsage7d adds value to the "usage_7d" field.
|
||||
func (_u *APIKeyUpdateOne) AddUsage7d(v float64) *APIKeyUpdateOne {
|
||||
_u.mutation.AddUsage7d(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWindow5hStart sets the "window_5h_start" field.
|
||||
func (_u *APIKeyUpdateOne) SetWindow5hStart(v time.Time) *APIKeyUpdateOne {
|
||||
_u.mutation.SetWindow5hStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdateOne) SetNillableWindow5hStart(v *time.Time) *APIKeyUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetWindow5hStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWindow5hStart clears the value of the "window_5h_start" field.
|
||||
func (_u *APIKeyUpdateOne) ClearWindow5hStart() *APIKeyUpdateOne {
|
||||
_u.mutation.ClearWindow5hStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWindow1dStart sets the "window_1d_start" field.
|
||||
func (_u *APIKeyUpdateOne) SetWindow1dStart(v time.Time) *APIKeyUpdateOne {
|
||||
_u.mutation.SetWindow1dStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdateOne) SetNillableWindow1dStart(v *time.Time) *APIKeyUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetWindow1dStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWindow1dStart clears the value of the "window_1d_start" field.
|
||||
func (_u *APIKeyUpdateOne) ClearWindow1dStart() *APIKeyUpdateOne {
|
||||
_u.mutation.ClearWindow1dStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWindow7dStart sets the "window_7d_start" field.
|
||||
func (_u *APIKeyUpdateOne) SetWindow7dStart(v time.Time) *APIKeyUpdateOne {
|
||||
_u.mutation.SetWindow7dStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil.
|
||||
func (_u *APIKeyUpdateOne) SetNillableWindow7dStart(v *time.Time) *APIKeyUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetWindow7dStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWindow7dStart clears the value of the "window_7d_start" field.
|
||||
func (_u *APIKeyUpdateOne) ClearWindow7dStart() *APIKeyUpdateOne {
|
||||
_u.mutation.ClearWindow7dStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne {
|
||||
return _u.SetUserID(v.ID)
|
||||
@@ -1033,6 +1459,60 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro
|
||||
if _u.mutation.ExpiresAtCleared() {
|
||||
_spec.ClearField(apikey.FieldExpiresAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.RateLimit5h(); ok {
|
||||
_spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedRateLimit5h(); ok {
|
||||
_spec.AddField(apikey.FieldRateLimit5h, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RateLimit1d(); ok {
|
||||
_spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedRateLimit1d(); ok {
|
||||
_spec.AddField(apikey.FieldRateLimit1d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.RateLimit7d(); ok {
|
||||
_spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedRateLimit7d(); ok {
|
||||
_spec.AddField(apikey.FieldRateLimit7d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Usage5h(); ok {
|
||||
_spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedUsage5h(); ok {
|
||||
_spec.AddField(apikey.FieldUsage5h, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Usage1d(); ok {
|
||||
_spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedUsage1d(); ok {
|
||||
_spec.AddField(apikey.FieldUsage1d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Usage7d(); ok {
|
||||
_spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedUsage7d(); ok {
|
||||
_spec.AddField(apikey.FieldUsage7d, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Window5hStart(); ok {
|
||||
_spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.Window5hStartCleared() {
|
||||
_spec.ClearField(apikey.FieldWindow5hStart, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.Window1dStart(); ok {
|
||||
_spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.Window1dStartCleared() {
|
||||
_spec.ClearField(apikey.FieldWindow1dStart, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.Window7dStart(); ok {
|
||||
_spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.Window7dStartCleared() {
|
||||
_spec.ClearField(apikey.FieldWindow7dStart, field.TypeTime)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
|
||||
@@ -24,6 +24,15 @@ var (
|
||||
{Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "quota_used", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "expires_at", Type: field.TypeTime, Nullable: true},
|
||||
{Name: "rate_limit_5h", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "rate_limit_1d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "rate_limit_7d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "usage_5h", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "usage_1d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "usage_7d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "window_5h_start", Type: field.TypeTime, Nullable: true},
|
||||
{Name: "window_1d_start", Type: field.TypeTime, Nullable: true},
|
||||
{Name: "window_7d_start", Type: field.TypeTime, Nullable: true},
|
||||
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
|
||||
{Name: "user_id", Type: field.TypeInt64},
|
||||
}
|
||||
@@ -35,13 +44,13 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "api_keys_groups_api_keys",
|
||||
Columns: []*schema.Column{APIKeysColumns[13]},
|
||||
Columns: []*schema.Column{APIKeysColumns[22]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "api_keys_users_api_keys",
|
||||
Columns: []*schema.Column{APIKeysColumns[14]},
|
||||
Columns: []*schema.Column{APIKeysColumns[23]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
@@ -50,12 +59,12 @@ var (
|
||||
{
|
||||
Name: "apikey_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{APIKeysColumns[14]},
|
||||
Columns: []*schema.Column{APIKeysColumns[23]},
|
||||
},
|
||||
{
|
||||
Name: "apikey_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{APIKeysColumns[13]},
|
||||
Columns: []*schema.Column{APIKeysColumns[22]},
|
||||
},
|
||||
{
|
||||
Name: "apikey_status",
|
||||
|
||||
@@ -91,6 +91,21 @@ type APIKeyMutation struct {
|
||||
quota_used *float64
|
||||
addquota_used *float64
|
||||
expires_at *time.Time
|
||||
rate_limit_5h *float64
|
||||
addrate_limit_5h *float64
|
||||
rate_limit_1d *float64
|
||||
addrate_limit_1d *float64
|
||||
rate_limit_7d *float64
|
||||
addrate_limit_7d *float64
|
||||
usage_5h *float64
|
||||
addusage_5h *float64
|
||||
usage_1d *float64
|
||||
addusage_1d *float64
|
||||
usage_7d *float64
|
||||
addusage_7d *float64
|
||||
window_5h_start *time.Time
|
||||
window_1d_start *time.Time
|
||||
window_7d_start *time.Time
|
||||
clearedFields map[string]struct{}
|
||||
user *int64
|
||||
cleareduser bool
|
||||
@@ -856,6 +871,489 @@ func (m *APIKeyMutation) ResetExpiresAt() {
|
||||
delete(m.clearedFields, apikey.FieldExpiresAt)
|
||||
}
|
||||
|
||||
// SetRateLimit5h sets the "rate_limit_5h" field.
|
||||
func (m *APIKeyMutation) SetRateLimit5h(f float64) {
|
||||
m.rate_limit_5h = &f
|
||||
m.addrate_limit_5h = nil
|
||||
}
|
||||
|
||||
// RateLimit5h returns the value of the "rate_limit_5h" field in the mutation.
|
||||
func (m *APIKeyMutation) RateLimit5h() (r float64, exists bool) {
|
||||
v := m.rate_limit_5h
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldRateLimit5h returns the old "rate_limit_5h" field's value of the APIKey entity.
|
||||
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *APIKeyMutation) OldRateLimit5h(ctx context.Context) (v float64, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldRateLimit5h is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldRateLimit5h requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldRateLimit5h: %w", err)
|
||||
}
|
||||
return oldValue.RateLimit5h, nil
|
||||
}
|
||||
|
||||
// AddRateLimit5h adds f to the "rate_limit_5h" field.
|
||||
func (m *APIKeyMutation) AddRateLimit5h(f float64) {
|
||||
if m.addrate_limit_5h != nil {
|
||||
*m.addrate_limit_5h += f
|
||||
} else {
|
||||
m.addrate_limit_5h = &f
|
||||
}
|
||||
}
|
||||
|
||||
// AddedRateLimit5h returns the value that was added to the "rate_limit_5h" field in this mutation.
|
||||
func (m *APIKeyMutation) AddedRateLimit5h() (r float64, exists bool) {
|
||||
v := m.addrate_limit_5h
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ResetRateLimit5h resets all changes to the "rate_limit_5h" field.
|
||||
func (m *APIKeyMutation) ResetRateLimit5h() {
|
||||
m.rate_limit_5h = nil
|
||||
m.addrate_limit_5h = nil
|
||||
}
|
||||
|
||||
// SetRateLimit1d sets the "rate_limit_1d" field.
|
||||
func (m *APIKeyMutation) SetRateLimit1d(f float64) {
|
||||
m.rate_limit_1d = &f
|
||||
m.addrate_limit_1d = nil
|
||||
}
|
||||
|
||||
// RateLimit1d returns the value of the "rate_limit_1d" field in the mutation.
|
||||
func (m *APIKeyMutation) RateLimit1d() (r float64, exists bool) {
|
||||
v := m.rate_limit_1d
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldRateLimit1d returns the old "rate_limit_1d" field's value of the APIKey entity.
|
||||
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *APIKeyMutation) OldRateLimit1d(ctx context.Context) (v float64, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldRateLimit1d is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldRateLimit1d requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldRateLimit1d: %w", err)
|
||||
}
|
||||
return oldValue.RateLimit1d, nil
|
||||
}
|
||||
|
||||
// AddRateLimit1d adds f to the "rate_limit_1d" field.
|
||||
func (m *APIKeyMutation) AddRateLimit1d(f float64) {
|
||||
if m.addrate_limit_1d != nil {
|
||||
*m.addrate_limit_1d += f
|
||||
} else {
|
||||
m.addrate_limit_1d = &f
|
||||
}
|
||||
}
|
||||
|
||||
// AddedRateLimit1d returns the value that was added to the "rate_limit_1d" field in this mutation.
|
||||
func (m *APIKeyMutation) AddedRateLimit1d() (r float64, exists bool) {
|
||||
v := m.addrate_limit_1d
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ResetRateLimit1d resets all changes to the "rate_limit_1d" field.
|
||||
func (m *APIKeyMutation) ResetRateLimit1d() {
|
||||
m.rate_limit_1d = nil
|
||||
m.addrate_limit_1d = nil
|
||||
}
|
||||
|
||||
// SetRateLimit7d sets the "rate_limit_7d" field.
|
||||
func (m *APIKeyMutation) SetRateLimit7d(f float64) {
|
||||
m.rate_limit_7d = &f
|
||||
m.addrate_limit_7d = nil
|
||||
}
|
||||
|
||||
// RateLimit7d returns the value of the "rate_limit_7d" field in the mutation.
|
||||
func (m *APIKeyMutation) RateLimit7d() (r float64, exists bool) {
|
||||
v := m.rate_limit_7d
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldRateLimit7d returns the old "rate_limit_7d" field's value of the APIKey entity.
|
||||
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *APIKeyMutation) OldRateLimit7d(ctx context.Context) (v float64, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldRateLimit7d is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldRateLimit7d requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldRateLimit7d: %w", err)
|
||||
}
|
||||
return oldValue.RateLimit7d, nil
|
||||
}
|
||||
|
||||
// AddRateLimit7d adds f to the "rate_limit_7d" field.
|
||||
func (m *APIKeyMutation) AddRateLimit7d(f float64) {
|
||||
if m.addrate_limit_7d != nil {
|
||||
*m.addrate_limit_7d += f
|
||||
} else {
|
||||
m.addrate_limit_7d = &f
|
||||
}
|
||||
}
|
||||
|
||||
// AddedRateLimit7d returns the value that was added to the "rate_limit_7d" field in this mutation.
|
||||
func (m *APIKeyMutation) AddedRateLimit7d() (r float64, exists bool) {
|
||||
v := m.addrate_limit_7d
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ResetRateLimit7d resets all changes to the "rate_limit_7d" field.
|
||||
func (m *APIKeyMutation) ResetRateLimit7d() {
|
||||
m.rate_limit_7d = nil
|
||||
m.addrate_limit_7d = nil
|
||||
}
|
||||
|
||||
// SetUsage5h sets the "usage_5h" field.
|
||||
func (m *APIKeyMutation) SetUsage5h(f float64) {
|
||||
m.usage_5h = &f
|
||||
m.addusage_5h = nil
|
||||
}
|
||||
|
||||
// Usage5h returns the value of the "usage_5h" field in the mutation.
|
||||
func (m *APIKeyMutation) Usage5h() (r float64, exists bool) {
|
||||
v := m.usage_5h
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldUsage5h returns the old "usage_5h" field's value of the APIKey entity.
|
||||
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *APIKeyMutation) OldUsage5h(ctx context.Context) (v float64, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldUsage5h is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldUsage5h requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldUsage5h: %w", err)
|
||||
}
|
||||
return oldValue.Usage5h, nil
|
||||
}
|
||||
|
||||
// AddUsage5h adds f to the "usage_5h" field.
|
||||
func (m *APIKeyMutation) AddUsage5h(f float64) {
|
||||
if m.addusage_5h != nil {
|
||||
*m.addusage_5h += f
|
||||
} else {
|
||||
m.addusage_5h = &f
|
||||
}
|
||||
}
|
||||
|
||||
// AddedUsage5h returns the value that was added to the "usage_5h" field in this mutation.
|
||||
func (m *APIKeyMutation) AddedUsage5h() (r float64, exists bool) {
|
||||
v := m.addusage_5h
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ResetUsage5h resets all changes to the "usage_5h" field.
|
||||
func (m *APIKeyMutation) ResetUsage5h() {
|
||||
m.usage_5h = nil
|
||||
m.addusage_5h = nil
|
||||
}
|
||||
|
||||
// SetUsage1d sets the "usage_1d" field.
|
||||
func (m *APIKeyMutation) SetUsage1d(f float64) {
|
||||
m.usage_1d = &f
|
||||
m.addusage_1d = nil
|
||||
}
|
||||
|
||||
// Usage1d returns the value of the "usage_1d" field in the mutation.
|
||||
func (m *APIKeyMutation) Usage1d() (r float64, exists bool) {
|
||||
v := m.usage_1d
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldUsage1d returns the old "usage_1d" field's value of the APIKey entity.
|
||||
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *APIKeyMutation) OldUsage1d(ctx context.Context) (v float64, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldUsage1d is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldUsage1d requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldUsage1d: %w", err)
|
||||
}
|
||||
return oldValue.Usage1d, nil
|
||||
}
|
||||
|
||||
// AddUsage1d adds f to the "usage_1d" field.
|
||||
func (m *APIKeyMutation) AddUsage1d(f float64) {
|
||||
if m.addusage_1d != nil {
|
||||
*m.addusage_1d += f
|
||||
} else {
|
||||
m.addusage_1d = &f
|
||||
}
|
||||
}
|
||||
|
||||
// AddedUsage1d returns the value that was added to the "usage_1d" field in this mutation.
|
||||
func (m *APIKeyMutation) AddedUsage1d() (r float64, exists bool) {
|
||||
v := m.addusage_1d
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ResetUsage1d resets all changes to the "usage_1d" field.
|
||||
func (m *APIKeyMutation) ResetUsage1d() {
|
||||
m.usage_1d = nil
|
||||
m.addusage_1d = nil
|
||||
}
|
||||
|
||||
// SetUsage7d sets the "usage_7d" field.
|
||||
func (m *APIKeyMutation) SetUsage7d(f float64) {
|
||||
m.usage_7d = &f
|
||||
m.addusage_7d = nil
|
||||
}
|
||||
|
||||
// Usage7d returns the value of the "usage_7d" field in the mutation.
|
||||
func (m *APIKeyMutation) Usage7d() (r float64, exists bool) {
|
||||
v := m.usage_7d
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldUsage7d returns the old "usage_7d" field's value of the APIKey entity.
|
||||
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *APIKeyMutation) OldUsage7d(ctx context.Context) (v float64, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldUsage7d is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldUsage7d requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldUsage7d: %w", err)
|
||||
}
|
||||
return oldValue.Usage7d, nil
|
||||
}
|
||||
|
||||
// AddUsage7d adds f to the "usage_7d" field.
|
||||
func (m *APIKeyMutation) AddUsage7d(f float64) {
|
||||
if m.addusage_7d != nil {
|
||||
*m.addusage_7d += f
|
||||
} else {
|
||||
m.addusage_7d = &f
|
||||
}
|
||||
}
|
||||
|
||||
// AddedUsage7d returns the value that was added to the "usage_7d" field in this mutation.
|
||||
func (m *APIKeyMutation) AddedUsage7d() (r float64, exists bool) {
|
||||
v := m.addusage_7d
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ResetUsage7d resets all changes to the "usage_7d" field.
|
||||
func (m *APIKeyMutation) ResetUsage7d() {
|
||||
m.usage_7d = nil
|
||||
m.addusage_7d = nil
|
||||
}
|
||||
|
||||
// SetWindow5hStart sets the "window_5h_start" field.
|
||||
func (m *APIKeyMutation) SetWindow5hStart(t time.Time) {
|
||||
m.window_5h_start = &t
|
||||
}
|
||||
|
||||
// Window5hStart returns the value of the "window_5h_start" field in the mutation.
|
||||
func (m *APIKeyMutation) Window5hStart() (r time.Time, exists bool) {
|
||||
v := m.window_5h_start
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldWindow5hStart returns the old "window_5h_start" field's value of the APIKey entity.
|
||||
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *APIKeyMutation) OldWindow5hStart(ctx context.Context) (v *time.Time, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldWindow5hStart is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldWindow5hStart requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldWindow5hStart: %w", err)
|
||||
}
|
||||
return oldValue.Window5hStart, nil
|
||||
}
|
||||
|
||||
// ClearWindow5hStart clears the value of the "window_5h_start" field.
|
||||
func (m *APIKeyMutation) ClearWindow5hStart() {
|
||||
m.window_5h_start = nil
|
||||
m.clearedFields[apikey.FieldWindow5hStart] = struct{}{}
|
||||
}
|
||||
|
||||
// Window5hStartCleared returns if the "window_5h_start" field was cleared in this mutation.
|
||||
func (m *APIKeyMutation) Window5hStartCleared() bool {
|
||||
_, ok := m.clearedFields[apikey.FieldWindow5hStart]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetWindow5hStart resets all changes to the "window_5h_start" field.
|
||||
func (m *APIKeyMutation) ResetWindow5hStart() {
|
||||
m.window_5h_start = nil
|
||||
delete(m.clearedFields, apikey.FieldWindow5hStart)
|
||||
}
|
||||
|
||||
// SetWindow1dStart sets the "window_1d_start" field.
|
||||
func (m *APIKeyMutation) SetWindow1dStart(t time.Time) {
|
||||
m.window_1d_start = &t
|
||||
}
|
||||
|
||||
// Window1dStart returns the value of the "window_1d_start" field in the mutation.
|
||||
func (m *APIKeyMutation) Window1dStart() (r time.Time, exists bool) {
|
||||
v := m.window_1d_start
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldWindow1dStart returns the old "window_1d_start" field's value of the APIKey entity.
|
||||
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *APIKeyMutation) OldWindow1dStart(ctx context.Context) (v *time.Time, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldWindow1dStart is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldWindow1dStart requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldWindow1dStart: %w", err)
|
||||
}
|
||||
return oldValue.Window1dStart, nil
|
||||
}
|
||||
|
||||
// ClearWindow1dStart clears the value of the "window_1d_start" field.
|
||||
func (m *APIKeyMutation) ClearWindow1dStart() {
|
||||
m.window_1d_start = nil
|
||||
m.clearedFields[apikey.FieldWindow1dStart] = struct{}{}
|
||||
}
|
||||
|
||||
// Window1dStartCleared returns if the "window_1d_start" field was cleared in this mutation.
|
||||
func (m *APIKeyMutation) Window1dStartCleared() bool {
|
||||
_, ok := m.clearedFields[apikey.FieldWindow1dStart]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetWindow1dStart resets all changes to the "window_1d_start" field.
|
||||
func (m *APIKeyMutation) ResetWindow1dStart() {
|
||||
m.window_1d_start = nil
|
||||
delete(m.clearedFields, apikey.FieldWindow1dStart)
|
||||
}
|
||||
|
||||
// SetWindow7dStart sets the "window_7d_start" field.
|
||||
func (m *APIKeyMutation) SetWindow7dStart(t time.Time) {
|
||||
m.window_7d_start = &t
|
||||
}
|
||||
|
||||
// Window7dStart returns the value of the "window_7d_start" field in the mutation.
|
||||
func (m *APIKeyMutation) Window7dStart() (r time.Time, exists bool) {
|
||||
v := m.window_7d_start
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldWindow7dStart returns the old "window_7d_start" field's value of the APIKey entity.
|
||||
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *APIKeyMutation) OldWindow7dStart(ctx context.Context) (v *time.Time, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldWindow7dStart is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldWindow7dStart requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldWindow7dStart: %w", err)
|
||||
}
|
||||
return oldValue.Window7dStart, nil
|
||||
}
|
||||
|
||||
// ClearWindow7dStart clears the value of the "window_7d_start" field.
|
||||
func (m *APIKeyMutation) ClearWindow7dStart() {
|
||||
m.window_7d_start = nil
|
||||
m.clearedFields[apikey.FieldWindow7dStart] = struct{}{}
|
||||
}
|
||||
|
||||
// Window7dStartCleared returns if the "window_7d_start" field was cleared in this mutation.
|
||||
func (m *APIKeyMutation) Window7dStartCleared() bool {
|
||||
_, ok := m.clearedFields[apikey.FieldWindow7dStart]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetWindow7dStart resets all changes to the "window_7d_start" field.
|
||||
func (m *APIKeyMutation) ResetWindow7dStart() {
|
||||
m.window_7d_start = nil
|
||||
delete(m.clearedFields, apikey.FieldWindow7dStart)
|
||||
}
|
||||
|
||||
// ClearUser clears the "user" edge to the User entity.
|
||||
func (m *APIKeyMutation) ClearUser() {
|
||||
m.cleareduser = true
|
||||
@@ -998,7 +1496,7 @@ func (m *APIKeyMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *APIKeyMutation) Fields() []string {
|
||||
fields := make([]string, 0, 14)
|
||||
fields := make([]string, 0, 23)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, apikey.FieldCreatedAt)
|
||||
}
|
||||
@@ -1041,6 +1539,33 @@ func (m *APIKeyMutation) Fields() []string {
|
||||
if m.expires_at != nil {
|
||||
fields = append(fields, apikey.FieldExpiresAt)
|
||||
}
|
||||
if m.rate_limit_5h != nil {
|
||||
fields = append(fields, apikey.FieldRateLimit5h)
|
||||
}
|
||||
if m.rate_limit_1d != nil {
|
||||
fields = append(fields, apikey.FieldRateLimit1d)
|
||||
}
|
||||
if m.rate_limit_7d != nil {
|
||||
fields = append(fields, apikey.FieldRateLimit7d)
|
||||
}
|
||||
if m.usage_5h != nil {
|
||||
fields = append(fields, apikey.FieldUsage5h)
|
||||
}
|
||||
if m.usage_1d != nil {
|
||||
fields = append(fields, apikey.FieldUsage1d)
|
||||
}
|
||||
if m.usage_7d != nil {
|
||||
fields = append(fields, apikey.FieldUsage7d)
|
||||
}
|
||||
if m.window_5h_start != nil {
|
||||
fields = append(fields, apikey.FieldWindow5hStart)
|
||||
}
|
||||
if m.window_1d_start != nil {
|
||||
fields = append(fields, apikey.FieldWindow1dStart)
|
||||
}
|
||||
if m.window_7d_start != nil {
|
||||
fields = append(fields, apikey.FieldWindow7dStart)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -1077,6 +1602,24 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.QuotaUsed()
|
||||
case apikey.FieldExpiresAt:
|
||||
return m.ExpiresAt()
|
||||
case apikey.FieldRateLimit5h:
|
||||
return m.RateLimit5h()
|
||||
case apikey.FieldRateLimit1d:
|
||||
return m.RateLimit1d()
|
||||
case apikey.FieldRateLimit7d:
|
||||
return m.RateLimit7d()
|
||||
case apikey.FieldUsage5h:
|
||||
return m.Usage5h()
|
||||
case apikey.FieldUsage1d:
|
||||
return m.Usage1d()
|
||||
case apikey.FieldUsage7d:
|
||||
return m.Usage7d()
|
||||
case apikey.FieldWindow5hStart:
|
||||
return m.Window5hStart()
|
||||
case apikey.FieldWindow1dStart:
|
||||
return m.Window1dStart()
|
||||
case apikey.FieldWindow7dStart:
|
||||
return m.Window7dStart()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -1114,6 +1657,24 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value,
|
||||
return m.OldQuotaUsed(ctx)
|
||||
case apikey.FieldExpiresAt:
|
||||
return m.OldExpiresAt(ctx)
|
||||
case apikey.FieldRateLimit5h:
|
||||
return m.OldRateLimit5h(ctx)
|
||||
case apikey.FieldRateLimit1d:
|
||||
return m.OldRateLimit1d(ctx)
|
||||
case apikey.FieldRateLimit7d:
|
||||
return m.OldRateLimit7d(ctx)
|
||||
case apikey.FieldUsage5h:
|
||||
return m.OldUsage5h(ctx)
|
||||
case apikey.FieldUsage1d:
|
||||
return m.OldUsage1d(ctx)
|
||||
case apikey.FieldUsage7d:
|
||||
return m.OldUsage7d(ctx)
|
||||
case apikey.FieldWindow5hStart:
|
||||
return m.OldWindow5hStart(ctx)
|
||||
case apikey.FieldWindow1dStart:
|
||||
return m.OldWindow1dStart(ctx)
|
||||
case apikey.FieldWindow7dStart:
|
||||
return m.OldWindow7dStart(ctx)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown APIKey field %s", name)
|
||||
}
|
||||
@@ -1221,6 +1782,69 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetExpiresAt(v)
|
||||
return nil
|
||||
case apikey.FieldRateLimit5h:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetRateLimit5h(v)
|
||||
return nil
|
||||
case apikey.FieldRateLimit1d:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetRateLimit1d(v)
|
||||
return nil
|
||||
case apikey.FieldRateLimit7d:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetRateLimit7d(v)
|
||||
return nil
|
||||
case apikey.FieldUsage5h:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetUsage5h(v)
|
||||
return nil
|
||||
case apikey.FieldUsage1d:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetUsage1d(v)
|
||||
return nil
|
||||
case apikey.FieldUsage7d:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetUsage7d(v)
|
||||
return nil
|
||||
case apikey.FieldWindow5hStart:
|
||||
v, ok := value.(time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetWindow5hStart(v)
|
||||
return nil
|
||||
case apikey.FieldWindow1dStart:
|
||||
v, ok := value.(time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetWindow1dStart(v)
|
||||
return nil
|
||||
case apikey.FieldWindow7dStart:
|
||||
v, ok := value.(time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetWindow7dStart(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown APIKey field %s", name)
|
||||
}
|
||||
@@ -1235,6 +1859,24 @@ func (m *APIKeyMutation) AddedFields() []string {
|
||||
if m.addquota_used != nil {
|
||||
fields = append(fields, apikey.FieldQuotaUsed)
|
||||
}
|
||||
if m.addrate_limit_5h != nil {
|
||||
fields = append(fields, apikey.FieldRateLimit5h)
|
||||
}
|
||||
if m.addrate_limit_1d != nil {
|
||||
fields = append(fields, apikey.FieldRateLimit1d)
|
||||
}
|
||||
if m.addrate_limit_7d != nil {
|
||||
fields = append(fields, apikey.FieldRateLimit7d)
|
||||
}
|
||||
if m.addusage_5h != nil {
|
||||
fields = append(fields, apikey.FieldUsage5h)
|
||||
}
|
||||
if m.addusage_1d != nil {
|
||||
fields = append(fields, apikey.FieldUsage1d)
|
||||
}
|
||||
if m.addusage_7d != nil {
|
||||
fields = append(fields, apikey.FieldUsage7d)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -1247,6 +1889,18 @@ func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) {
|
||||
return m.AddedQuota()
|
||||
case apikey.FieldQuotaUsed:
|
||||
return m.AddedQuotaUsed()
|
||||
case apikey.FieldRateLimit5h:
|
||||
return m.AddedRateLimit5h()
|
||||
case apikey.FieldRateLimit1d:
|
||||
return m.AddedRateLimit1d()
|
||||
case apikey.FieldRateLimit7d:
|
||||
return m.AddedRateLimit7d()
|
||||
case apikey.FieldUsage5h:
|
||||
return m.AddedUsage5h()
|
||||
case apikey.FieldUsage1d:
|
||||
return m.AddedUsage1d()
|
||||
case apikey.FieldUsage7d:
|
||||
return m.AddedUsage7d()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -1270,6 +1924,48 @@ func (m *APIKeyMutation) AddField(name string, value ent.Value) error {
|
||||
}
|
||||
m.AddQuotaUsed(v)
|
||||
return nil
|
||||
case apikey.FieldRateLimit5h:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddRateLimit5h(v)
|
||||
return nil
|
||||
case apikey.FieldRateLimit1d:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddRateLimit1d(v)
|
||||
return nil
|
||||
case apikey.FieldRateLimit7d:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddRateLimit7d(v)
|
||||
return nil
|
||||
case apikey.FieldUsage5h:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddUsage5h(v)
|
||||
return nil
|
||||
case apikey.FieldUsage1d:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddUsage1d(v)
|
||||
return nil
|
||||
case apikey.FieldUsage7d:
|
||||
v, ok := value.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddUsage7d(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown APIKey numeric field %s", name)
|
||||
}
|
||||
@@ -1296,6 +1992,15 @@ func (m *APIKeyMutation) ClearedFields() []string {
|
||||
if m.FieldCleared(apikey.FieldExpiresAt) {
|
||||
fields = append(fields, apikey.FieldExpiresAt)
|
||||
}
|
||||
if m.FieldCleared(apikey.FieldWindow5hStart) {
|
||||
fields = append(fields, apikey.FieldWindow5hStart)
|
||||
}
|
||||
if m.FieldCleared(apikey.FieldWindow1dStart) {
|
||||
fields = append(fields, apikey.FieldWindow1dStart)
|
||||
}
|
||||
if m.FieldCleared(apikey.FieldWindow7dStart) {
|
||||
fields = append(fields, apikey.FieldWindow7dStart)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -1328,6 +2033,15 @@ func (m *APIKeyMutation) ClearField(name string) error {
|
||||
case apikey.FieldExpiresAt:
|
||||
m.ClearExpiresAt()
|
||||
return nil
|
||||
case apikey.FieldWindow5hStart:
|
||||
m.ClearWindow5hStart()
|
||||
return nil
|
||||
case apikey.FieldWindow1dStart:
|
||||
m.ClearWindow1dStart()
|
||||
return nil
|
||||
case apikey.FieldWindow7dStart:
|
||||
m.ClearWindow7dStart()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown APIKey nullable field %s", name)
|
||||
}
|
||||
@@ -1378,6 +2092,33 @@ func (m *APIKeyMutation) ResetField(name string) error {
|
||||
case apikey.FieldExpiresAt:
|
||||
m.ResetExpiresAt()
|
||||
return nil
|
||||
case apikey.FieldRateLimit5h:
|
||||
m.ResetRateLimit5h()
|
||||
return nil
|
||||
case apikey.FieldRateLimit1d:
|
||||
m.ResetRateLimit1d()
|
||||
return nil
|
||||
case apikey.FieldRateLimit7d:
|
||||
m.ResetRateLimit7d()
|
||||
return nil
|
||||
case apikey.FieldUsage5h:
|
||||
m.ResetUsage5h()
|
||||
return nil
|
||||
case apikey.FieldUsage1d:
|
||||
m.ResetUsage1d()
|
||||
return nil
|
||||
case apikey.FieldUsage7d:
|
||||
m.ResetUsage7d()
|
||||
return nil
|
||||
case apikey.FieldWindow5hStart:
|
||||
m.ResetWindow5hStart()
|
||||
return nil
|
||||
case apikey.FieldWindow1dStart:
|
||||
m.ResetWindow1dStart()
|
||||
return nil
|
||||
case apikey.FieldWindow7dStart:
|
||||
m.ResetWindow7dStart()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown APIKey field %s", name)
|
||||
}
|
||||
|
||||
@@ -102,6 +102,30 @@ func init() {
|
||||
apikeyDescQuotaUsed := apikeyFields[9].Descriptor()
|
||||
// apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field.
|
||||
apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64)
|
||||
// apikeyDescRateLimit5h is the schema descriptor for rate_limit_5h field.
|
||||
apikeyDescRateLimit5h := apikeyFields[11].Descriptor()
|
||||
// apikey.DefaultRateLimit5h holds the default value on creation for the rate_limit_5h field.
|
||||
apikey.DefaultRateLimit5h = apikeyDescRateLimit5h.Default.(float64)
|
||||
// apikeyDescRateLimit1d is the schema descriptor for rate_limit_1d field.
|
||||
apikeyDescRateLimit1d := apikeyFields[12].Descriptor()
|
||||
// apikey.DefaultRateLimit1d holds the default value on creation for the rate_limit_1d field.
|
||||
apikey.DefaultRateLimit1d = apikeyDescRateLimit1d.Default.(float64)
|
||||
// apikeyDescRateLimit7d is the schema descriptor for rate_limit_7d field.
|
||||
apikeyDescRateLimit7d := apikeyFields[13].Descriptor()
|
||||
// apikey.DefaultRateLimit7d holds the default value on creation for the rate_limit_7d field.
|
||||
apikey.DefaultRateLimit7d = apikeyDescRateLimit7d.Default.(float64)
|
||||
// apikeyDescUsage5h is the schema descriptor for usage_5h field.
|
||||
apikeyDescUsage5h := apikeyFields[14].Descriptor()
|
||||
// apikey.DefaultUsage5h holds the default value on creation for the usage_5h field.
|
||||
apikey.DefaultUsage5h = apikeyDescUsage5h.Default.(float64)
|
||||
// apikeyDescUsage1d is the schema descriptor for usage_1d field.
|
||||
apikeyDescUsage1d := apikeyFields[15].Descriptor()
|
||||
// apikey.DefaultUsage1d holds the default value on creation for the usage_1d field.
|
||||
apikey.DefaultUsage1d = apikeyDescUsage1d.Default.(float64)
|
||||
// apikeyDescUsage7d is the schema descriptor for usage_7d field.
|
||||
apikeyDescUsage7d := apikeyFields[16].Descriptor()
|
||||
// apikey.DefaultUsage7d holds the default value on creation for the usage_7d field.
|
||||
apikey.DefaultUsage7d = apikeyDescUsage7d.Default.(float64)
|
||||
accountMixin := schema.Account{}.Mixin()
|
||||
accountMixinHooks1 := accountMixin[1].Hooks()
|
||||
account.Hooks[0] = accountMixinHooks1[0]
|
||||
|
||||
@@ -74,6 +74,47 @@ func (APIKey) Fields() []ent.Field {
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("Expiration time for this API key (null = never expires)"),
|
||||
|
||||
// ========== Rate limit fields ==========
|
||||
// Rate limit configuration (0 = unlimited)
|
||||
field.Float("rate_limit_5h").
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||
Default(0).
|
||||
Comment("Rate limit in USD per 5 hours (0 = unlimited)"),
|
||||
field.Float("rate_limit_1d").
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||
Default(0).
|
||||
Comment("Rate limit in USD per day (0 = unlimited)"),
|
||||
field.Float("rate_limit_7d").
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||
Default(0).
|
||||
Comment("Rate limit in USD per 7 days (0 = unlimited)"),
|
||||
// Rate limit usage tracking
|
||||
field.Float("usage_5h").
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||
Default(0).
|
||||
Comment("Used amount in USD for the current 5h window"),
|
||||
field.Float("usage_1d").
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||
Default(0).
|
||||
Comment("Used amount in USD for the current 1d window"),
|
||||
field.Float("usage_7d").
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||
Default(0).
|
||||
Comment("Used amount in USD for the current 7d window"),
|
||||
// Window start times
|
||||
field.Time("window_5h_start").
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("Start time of the current 5h rate limit window"),
|
||||
field.Time("window_1d_start").
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("Start time of the current 1d rate limit window"),
|
||||
field.Time("window_7d_start").
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("Start time of the current 7d rate limit window"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -180,8 +180,6 @@ require (
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
golang.org/x/tools v0.41.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
|
||||
@@ -516,7 +516,7 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
|
||||
type GatewayOpenAIWSConfig struct {
|
||||
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
|
||||
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
|
||||
// IngressModeDefault: ingress 默认模式(off/shared/dedicated)
|
||||
// IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough)
|
||||
IngressModeDefault string `mapstructure:"ingress_mode_default"`
|
||||
// Enabled: 全局总开关(默认 true)
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
@@ -872,7 +872,8 @@ type DefaultConfig struct {
|
||||
}
|
||||
|
||||
type RateLimitConfig struct {
|
||||
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
||||
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
||||
OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` // OAuth 401临时不可调度冷却(分钟)
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheConfig API Key 认证缓存配置
|
||||
@@ -1226,7 +1227,7 @@ func setDefaults() {
|
||||
|
||||
// Ops (vNext)
|
||||
viper.SetDefault("ops.enabled", true)
|
||||
viper.SetDefault("ops.use_preaggregated_tables", false)
|
||||
viper.SetDefault("ops.use_preaggregated_tables", true)
|
||||
viper.SetDefault("ops.cleanup.enabled", true)
|
||||
viper.SetDefault("ops.cleanup.schedule", "0 2 * * *")
|
||||
// Retention days: vNext defaults to 30 days across ops datasets.
|
||||
@@ -1260,6 +1261,7 @@ func setDefaults() {
|
||||
|
||||
// RateLimit
|
||||
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
|
||||
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10)
|
||||
|
||||
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移)
|
||||
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json")
|
||||
@@ -1333,7 +1335,7 @@ func setDefaults() {
|
||||
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
|
||||
viper.SetDefault("gateway.openai_ws.enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
|
||||
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
|
||||
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
|
||||
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
|
||||
viper.SetDefault("gateway.openai_ws.force_http", false)
|
||||
@@ -2041,9 +2043,11 @@ func (c *Config) Validate() error {
|
||||
}
|
||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
|
||||
switch mode {
|
||||
case "off", "shared", "dedicated":
|
||||
case "off", "ctx_pool", "passthrough":
|
||||
case "shared", "dedicated":
|
||||
slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode)
|
||||
default:
|
||||
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
|
||||
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough")
|
||||
}
|
||||
}
|
||||
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
|
||||
|
||||
@@ -153,8 +153,8 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
|
||||
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
|
||||
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
|
||||
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
|
||||
if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" {
|
||||
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
|
||||
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
|
||||
},
|
||||
{
|
||||
name: "ingress_mode_default 必须为 off|shared|dedicated",
|
||||
name: "ingress_mode_default 必须为 off|ctx_pool|passthrough",
|
||||
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
|
||||
wantErr: "gateway.openai_ws.ingress_mode_default",
|
||||
},
|
||||
|
||||
@@ -217,6 +217,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
lite := parseBoolQueryWithDefault(c.Query("lite"), false)
|
||||
|
||||
var groupID int64
|
||||
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||
@@ -235,10 +236,16 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
accountIDs[i] = acc.ID
|
||||
}
|
||||
|
||||
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
|
||||
if err != nil {
|
||||
// Log error but don't fail the request, just use 0 for all
|
||||
concurrencyCounts = make(map[int64]int)
|
||||
concurrencyCounts := make(map[int64]int)
|
||||
var windowCosts map[int64]float64
|
||||
var activeSessions map[int64]int
|
||||
var rpmCounts map[int64]int
|
||||
|
||||
// 始终获取并发数(Redis ZCARD,极低开销)
|
||||
if h.concurrencyService != nil {
|
||||
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
|
||||
concurrencyCounts = cc
|
||||
}
|
||||
}
|
||||
|
||||
// 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
@@ -262,12 +269,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 并行获取窗口费用、活跃会话数和 RPM 计数
|
||||
var windowCosts map[int64]float64
|
||||
var activeSessions map[int64]int
|
||||
var rpmCounts map[int64]int
|
||||
|
||||
// 获取 RPM 计数(批量查询)
|
||||
// 始终获取 RPM 计数(Redis GET,极低开销)
|
||||
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
|
||||
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
|
||||
if rpmCounts == nil {
|
||||
@@ -275,7 +277,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
||||
// 始终获取活跃会话数(Redis ZCARD,低开销)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
@@ -283,8 +285,8 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 获取窗口费用(并行查询)
|
||||
if len(windowCostAccountIDs) > 0 {
|
||||
// 仅非 lite 模式获取窗口费用(PostgreSQL 聚合查询,高开销)
|
||||
if !lite && len(windowCostAccountIDs) > 0 {
|
||||
windowCosts = make(map[int64]float64)
|
||||
var mu sync.Mutex
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
@@ -344,7 +346,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
result[i] = item
|
||||
}
|
||||
|
||||
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search)
|
||||
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search, lite)
|
||||
if etag != "" {
|
||||
c.Header("ETag", etag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
@@ -362,6 +364,7 @@ func buildAccountsListETag(
|
||||
total int64,
|
||||
page, pageSize int,
|
||||
platform, accountType, status, search string,
|
||||
lite bool,
|
||||
) string {
|
||||
payload := struct {
|
||||
Total int64 `json:"total"`
|
||||
@@ -371,6 +374,7 @@ func buildAccountsListETag(
|
||||
AccountType string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Search string `json:"search"`
|
||||
Lite bool `json:"lite"`
|
||||
Items []AccountWithConcurrency `json:"items"`
|
||||
}{
|
||||
Total: total,
|
||||
@@ -380,6 +384,7 @@ func buildAccountsListETag(
|
||||
AccountType: accountType,
|
||||
Status: status,
|
||||
Search: search,
|
||||
Lite: lite,
|
||||
Items: items,
|
||||
}
|
||||
raw, err := json.Marshal(payload)
|
||||
@@ -1398,18 +1403,41 @@ func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.AccountIDs) == 0 {
|
||||
accountIDs := normalizeInt64IDList(req.AccountIDs)
|
||||
if len(accountIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs)
|
||||
cacheKey := buildAccountTodayStatsBatchCacheKey(accountIDs)
|
||||
if cached, ok := accountTodayStatsBatchCache.Get(cacheKey); ok {
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), accountIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
payload := gin.H{"stats": stats}
|
||||
cached := accountTodayStatsBatchCache.Set(cacheKey, payload)
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// SetSchedulableRequest represents the request body for setting schedulable status
|
||||
|
||||
25
backend/internal/handler/admin/account_today_stats_cache.go
Normal file
25
backend/internal/handler/admin/account_today_stats_cache.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var accountTodayStatsBatchCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
func buildAccountTodayStatsBatchCacheKey(accountIDs []int64) string {
|
||||
if len(accountIDs) == 0 {
|
||||
return "accounts_today_stats_empty"
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(accountIDs) * 6)
|
||||
_, _ = b.WriteString("accounts_today_stats:")
|
||||
for i, id := range accountIDs {
|
||||
if i > 0 {
|
||||
_ = b.WriteByte(',')
|
||||
}
|
||||
_, _ = b.WriteString(strconv.FormatInt(id, 10))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -460,6 +461,9 @@ type BatchUsersUsageRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
||||
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||
// POST /api/v1/admin/dashboard/users-usage
|
||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
@@ -469,18 +473,34 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
userIDs := normalizeInt64IDList(req.UserIDs)
|
||||
if len(userIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
UserIDs []int64 `json:"user_ids"`
|
||||
}{
|
||||
UserIDs: userIDs,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := dashboardBatchUsersUsageCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), userIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
payload := gin.H{"stats": stats}
|
||||
dashboardBatchUsersUsageCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
|
||||
@@ -497,16 +517,32 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.APIKeyIDs) == 0 {
|
||||
apiKeyIDs := normalizeInt64IDList(req.APIKeyIDs)
|
||||
if len(apiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
APIKeyIDs []int64 `json:"api_key_ids"`
|
||||
}{
|
||||
APIKeyIDs: apiKeyIDs,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := dashboardBatchAPIKeysUsageCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), apiKeyIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
payload := gin.H{"stats": stats}
|
||||
dashboardBatchAPIKeysUsageCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
292
backend/internal/handler/admin/dashboard_snapshot_v2_handler.go
Normal file
292
backend/internal/handler/admin/dashboard_snapshot_v2_handler.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
type dashboardSnapshotV2Stats struct {
|
||||
usagestats.DashboardStats
|
||||
Uptime int64 `json:"uptime"`
|
||||
}
|
||||
|
||||
type dashboardSnapshotV2Response struct {
|
||||
GeneratedAt string `json:"generated_at"`
|
||||
|
||||
StartDate string `json:"start_date"`
|
||||
EndDate string `json:"end_date"`
|
||||
Granularity string `json:"granularity"`
|
||||
|
||||
Stats *dashboardSnapshotV2Stats `json:"stats,omitempty"`
|
||||
Trend []usagestats.TrendDataPoint `json:"trend,omitempty"`
|
||||
Models []usagestats.ModelStat `json:"models,omitempty"`
|
||||
Groups []usagestats.GroupStat `json:"groups,omitempty"`
|
||||
UsersTrend []usagestats.UserUsageTrendPoint `json:"users_trend,omitempty"`
|
||||
}
|
||||
|
||||
type dashboardSnapshotV2Filters struct {
|
||||
UserID int64
|
||||
APIKeyID int64
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Model string
|
||||
RequestType *int16
|
||||
Stream *bool
|
||||
BillingType *int8
|
||||
}
|
||||
|
||||
type dashboardSnapshotV2CacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Granularity string `json:"granularity"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
Model string `json:"model"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
IncludeStats bool `json:"include_stats"`
|
||||
IncludeTrend bool `json:"include_trend"`
|
||||
IncludeModels bool `json:"include_models"`
|
||||
IncludeGroups bool `json:"include_groups"`
|
||||
IncludeUsersTrend bool `json:"include_users_trend"`
|
||||
UsersTrendLimit int `json:"users_trend_limit"`
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := strings.TrimSpace(c.DefaultQuery("granularity", "day"))
|
||||
if granularity != "hour" {
|
||||
granularity = "day"
|
||||
}
|
||||
|
||||
includeStats := parseBoolQueryWithDefault(c.Query("include_stats"), true)
|
||||
includeTrend := parseBoolQueryWithDefault(c.Query("include_trend"), true)
|
||||
includeModels := parseBoolQueryWithDefault(c.Query("include_model_stats"), true)
|
||||
includeGroups := parseBoolQueryWithDefault(c.Query("include_group_stats"), false)
|
||||
includeUsersTrend := parseBoolQueryWithDefault(c.Query("include_users_trend"), false)
|
||||
usersTrendLimit := 12
|
||||
if raw := strings.TrimSpace(c.Query("users_trend_limit")); raw != "" {
|
||||
if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 50 {
|
||||
usersTrendLimit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
filters, err := parseDashboardSnapshotV2Filters(c)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
keyRaw, _ := json.Marshal(dashboardSnapshotV2CacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
UserID: filters.UserID,
|
||||
APIKeyID: filters.APIKeyID,
|
||||
AccountID: filters.AccountID,
|
||||
GroupID: filters.GroupID,
|
||||
Model: filters.Model,
|
||||
RequestType: filters.RequestType,
|
||||
Stream: filters.Stream,
|
||||
BillingType: filters.BillingType,
|
||||
IncludeStats: includeStats,
|
||||
IncludeTrend: includeTrend,
|
||||
IncludeModels: includeModels,
|
||||
IncludeGroups: includeGroups,
|
||||
IncludeUsersTrend: includeUsersTrend,
|
||||
UsersTrendLimit: usersTrendLimit,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
|
||||
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
resp := &dashboardSnapshotV2Response{
|
||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
StartDate: startTime.Format("2006-01-02"),
|
||||
EndDate: endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
Granularity: granularity,
|
||||
}
|
||||
|
||||
if includeStats {
|
||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||
return
|
||||
}
|
||||
resp.Stats = &dashboardSnapshotV2Stats{
|
||||
DashboardStats: *stats,
|
||||
Uptime: int64(time.Since(h.startTime).Seconds()),
|
||||
}
|
||||
}
|
||||
|
||||
if includeTrend {
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
filters.UserID,
|
||||
filters.APIKeyID,
|
||||
filters.AccountID,
|
||||
filters.GroupID,
|
||||
filters.Model,
|
||||
filters.RequestType,
|
||||
filters.Stream,
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
}
|
||||
resp.Trend = trend
|
||||
}
|
||||
|
||||
if includeModels {
|
||||
models, err := h.dashboardService.GetModelStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
filters.APIKeyID,
|
||||
filters.AccountID,
|
||||
filters.GroupID,
|
||||
filters.RequestType,
|
||||
filters.Stream,
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
}
|
||||
resp.Models = models
|
||||
}
|
||||
|
||||
if includeGroups {
|
||||
groups, err := h.dashboardService.GetGroupStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
filters.APIKeyID,
|
||||
filters.AccountID,
|
||||
filters.GroupID,
|
||||
filters.RequestType,
|
||||
filters.Stream,
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
}
|
||||
resp.Groups = groups
|
||||
}
|
||||
|
||||
if includeUsersTrend {
|
||||
usersTrend, err := h.dashboardService.GetUserUsageTrend(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
usersTrendLimit,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
}
|
||||
resp.UsersTrend = usersTrend
|
||||
}
|
||||
|
||||
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, resp)
|
||||
}
|
||||
|
||||
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
||||
filters := &dashboardSnapshotV2Filters{
|
||||
Model: strings.TrimSpace(c.Query("model")),
|
||||
}
|
||||
|
||||
if userIDStr := strings.TrimSpace(c.Query("user_id")); userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.UserID = id
|
||||
}
|
||||
if apiKeyIDStr := strings.TrimSpace(c.Query("api_key_id")); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.APIKeyID = id
|
||||
}
|
||||
if accountIDStr := strings.TrimSpace(c.Query("account_id")); accountIDStr != "" {
|
||||
id, err := strconv.ParseInt(accountIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.AccountID = id
|
||||
}
|
||||
if groupIDStr := strings.TrimSpace(c.Query("group_id")); groupIDStr != "" {
|
||||
id, err := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.GroupID = id
|
||||
}
|
||||
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
value := int16(parsed)
|
||||
filters.RequestType = &value
|
||||
} else if streamStr := strings.TrimSpace(c.Query("stream")); streamStr != "" {
|
||||
streamVal, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.Stream = &streamVal
|
||||
}
|
||||
|
||||
if billingTypeStr := strings.TrimSpace(c.Query("billing_type")); billingTypeStr != "" {
|
||||
v, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bt := int8(v)
|
||||
filters.BillingType = &bt
|
||||
}
|
||||
|
||||
return filters, nil
|
||||
}
|
||||
25
backend/internal/handler/admin/id_list_utils.go
Normal file
25
backend/internal/handler/admin/id_list_utils.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package admin
|
||||
|
||||
import "sort"
|
||||
|
||||
func normalizeInt64IDList(ids []int64) []int64 {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]int64, 0, len(ids))
|
||||
seen := make(map[int64]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
|
||||
sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
|
||||
return out
|
||||
}
|
||||
57
backend/internal/handler/admin/id_list_utils_test.go
Normal file
57
backend/internal/handler/admin/id_list_utils_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNormalizeInt64IDList(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in []int64
|
||||
want []int64
|
||||
}{
|
||||
{"nil input", nil, nil},
|
||||
{"empty input", []int64{}, nil},
|
||||
{"single element", []int64{5}, []int64{5}},
|
||||
{"already sorted unique", []int64{1, 2, 3}, []int64{1, 2, 3}},
|
||||
{"duplicates removed", []int64{3, 1, 3, 2, 1}, []int64{1, 2, 3}},
|
||||
{"zero filtered", []int64{0, 1, 2}, []int64{1, 2}},
|
||||
{"negative filtered", []int64{-5, -1, 3}, []int64{3}},
|
||||
{"all invalid", []int64{0, -1, -2}, []int64{}},
|
||||
{"sorted output", []int64{9, 3, 7, 1}, []int64{1, 3, 7, 9}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := normalizeInt64IDList(tc.in)
|
||||
if tc.want == nil {
|
||||
require.Nil(t, got)
|
||||
} else {
|
||||
require.Equal(t, tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAccountTodayStatsBatchCacheKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ids []int64
|
||||
want string
|
||||
}{
|
||||
{"empty", nil, "accounts_today_stats_empty"},
|
||||
{"single", []int64{42}, "accounts_today_stats:42"},
|
||||
{"multiple", []int64{1, 2, 3}, "accounts_today_stats:1,2,3"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := buildAccountTodayStatsBatchCacheKey(tc.ids)
|
||||
require.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
145
backend/internal/handler/admin/ops_snapshot_v2_handler.go
Normal file
145
backend/internal/handler/admin/ops_snapshot_v2_handler.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var opsDashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
type opsDashboardSnapshotV2Response struct {
|
||||
GeneratedAt string `json:"generated_at"`
|
||||
|
||||
Overview *service.OpsDashboardOverview `json:"overview"`
|
||||
ThroughputTrend *service.OpsThroughputTrendResponse `json:"throughput_trend"`
|
||||
ErrorTrend *service.OpsErrorTrendResponse `json:"error_trend"`
|
||||
}
|
||||
|
||||
type opsDashboardSnapshotV2CacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Platform string `json:"platform"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
QueryMode service.OpsQueryMode `json:"mode"`
|
||||
BucketSecond int `json:"bucket_second"`
|
||||
}
|
||||
|
||||
// GetDashboardSnapshotV2 returns ops dashboard core snapshot in one request.
|
||||
// GET /api/v1/admin/ops/dashboard/snapshot-v2
|
||||
func (h *OpsHandler) GetDashboardSnapshotV2(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
startTime, endTime, err := parseOpsTimeRange(c, "1h")
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
filter := &service.OpsDashboardFilter{
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
Platform: strings.TrimSpace(c.Query("platform")),
|
||||
QueryMode: parseOpsQueryMode(c),
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||
id, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid group_id")
|
||||
return
|
||||
}
|
||||
filter.GroupID = &id
|
||||
}
|
||||
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
|
||||
|
||||
keyRaw, _ := json.Marshal(opsDashboardSnapshotV2CacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Platform: filter.Platform,
|
||||
GroupID: filter.GroupID,
|
||||
QueryMode: filter.QueryMode,
|
||||
BucketSecond: bucketSeconds,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
|
||||
if cached, ok := opsDashboardSnapshotV2Cache.Get(cacheKey); ok {
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
overview *service.OpsDashboardOverview
|
||||
trend *service.OpsThroughputTrendResponse
|
||||
errTrend *service.OpsErrorTrendResponse
|
||||
)
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
g.Go(func() error {
|
||||
f := *filter
|
||||
result, err := h.opsService.GetDashboardOverview(gctx, &f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
overview = result
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
f := *filter
|
||||
result, err := h.opsService.GetThroughputTrend(gctx, &f, bucketSeconds)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
trend = result
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
f := *filter
|
||||
result, err := h.opsService.GetErrorTrend(gctx, &f, bucketSeconds)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
errTrend = result
|
||||
return nil
|
||||
})
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := &opsDashboardSnapshotV2Response{
|
||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
Overview: overview,
|
||||
ThroughputTrend: trend,
|
||||
ErrorTrend: errTrend,
|
||||
}
|
||||
|
||||
cached := opsDashboardSnapshotV2Cache.Set(cacheKey, resp)
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, resp)
|
||||
}
|
||||
155
backend/internal/handler/admin/scheduled_test_handler.go
Normal file
155
backend/internal/handler/admin/scheduled_test_handler.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ScheduledTestHandler handles admin scheduled-test-plan management.
|
||||
type ScheduledTestHandler struct {
|
||||
scheduledTestSvc *service.ScheduledTestService
|
||||
}
|
||||
|
||||
// NewScheduledTestHandler creates a new ScheduledTestHandler.
|
||||
func NewScheduledTestHandler(scheduledTestSvc *service.ScheduledTestService) *ScheduledTestHandler {
|
||||
return &ScheduledTestHandler{scheduledTestSvc: scheduledTestSvc}
|
||||
}
|
||||
|
||||
type createScheduledTestPlanRequest struct {
|
||||
AccountID int64 `json:"account_id" binding:"required"`
|
||||
ModelID string `json:"model_id"`
|
||||
CronExpression string `json:"cron_expression" binding:"required"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
}
|
||||
|
||||
type updateScheduledTestPlanRequest struct {
|
||||
ModelID string `json:"model_id"`
|
||||
CronExpression string `json:"cron_expression"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
MaxResults int `json:"max_results"`
|
||||
}
|
||||
|
||||
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
|
||||
func (h *ScheduledTestHandler) ListByAccount(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid account id")
|
||||
return
|
||||
}
|
||||
|
||||
plans, err := h.scheduledTestSvc.ListPlansByAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, plans)
|
||||
}
|
||||
|
||||
// Create POST /admin/scheduled-test-plans
|
||||
func (h *ScheduledTestHandler) Create(c *gin.Context) {
|
||||
var req createScheduledTestPlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
plan := &service.ScheduledTestPlan{
|
||||
AccountID: req.AccountID,
|
||||
ModelID: req.ModelID,
|
||||
CronExpression: req.CronExpression,
|
||||
Enabled: true,
|
||||
MaxResults: req.MaxResults,
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
plan.Enabled = *req.Enabled
|
||||
}
|
||||
|
||||
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// Update PUT /admin/scheduled-test-plans/:id
|
||||
func (h *ScheduledTestHandler) Update(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
existing, err := h.scheduledTestSvc.GetPlan(c.Request.Context(), planID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "plan not found")
|
||||
return
|
||||
}
|
||||
|
||||
var req updateScheduledTestPlanRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.ModelID != "" {
|
||||
existing.ModelID = req.ModelID
|
||||
}
|
||||
if req.CronExpression != "" {
|
||||
existing.CronExpression = req.CronExpression
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
existing.Enabled = *req.Enabled
|
||||
}
|
||||
if req.MaxResults > 0 {
|
||||
existing.MaxResults = req.MaxResults
|
||||
}
|
||||
|
||||
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
}
|
||||
|
||||
// Delete DELETE /admin/scheduled-test-plans/:id
|
||||
func (h *ScheduledTestHandler) Delete(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.scheduledTestSvc.DeletePlan(c.Request.Context(), planID); err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "deleted"})
|
||||
}
|
||||
|
||||
// ListResults GET /admin/scheduled-test-plans/:id/results
|
||||
func (h *ScheduledTestHandler) ListResults(c *gin.Context) {
|
||||
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid plan id")
|
||||
return
|
||||
}
|
||||
|
||||
limit := 50
|
||||
if l, err := strconv.Atoi(c.Query("limit")); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
|
||||
results, err := h.scheduledTestSvc.ListResults(c.Request.Context(), planID, limit)
|
||||
if err != nil {
|
||||
response.InternalError(c, err.Error())
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, results)
|
||||
}
|
||||
@@ -77,6 +77,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
@@ -123,18 +124,20 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
OpsQueryModeDefault: settings.OpsQueryModeDefault,
|
||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
|
||||
// 邮件服务设置
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
@@ -193,6 +196,9 @@ type UpdateSettingsRequest struct {
|
||||
OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"`
|
||||
|
||||
MinClaudeCodeVersion string `json:"min_claude_code_version"`
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -422,49 +428,51 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
SoraClientEnabled: req.SoraClientEnabled,
|
||||
CustomMenuItems: customMenuJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
SoraClientEnabled: req.SoraClientEnabled,
|
||||
CustomMenuItems: customMenuJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -515,6 +523,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||
@@ -561,6 +570,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
|
||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -592,6 +602,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
|
||||
changed = append(changed, "email_verify_enabled")
|
||||
}
|
||||
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
|
||||
changed = append(changed, "registration_email_suffix_whitelist")
|
||||
}
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
@@ -709,6 +722,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion {
|
||||
changed = append(changed, "min_claude_code_version")
|
||||
}
|
||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||
}
|
||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||
changed = append(changed, "purchase_subscription_enabled")
|
||||
}
|
||||
@@ -738,6 +754,18 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
|
||||
return normalized
|
||||
}
|
||||
|
||||
func equalStringSlice(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
@@ -791,7 +819,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
||||
|
||||
err := h.emailService.TestSMTPConnectionWithConfig(config)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -877,7 +905,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
`
|
||||
|
||||
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
response.BadRequest(c, "Failed to send test email: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
95
backend/internal/handler/admin/snapshot_cache.go
Normal file
95
backend/internal/handler/admin/snapshot_cache.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type snapshotCacheEntry struct {
|
||||
ETag string
|
||||
Payload any
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type snapshotCache struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
items map[string]snapshotCacheEntry
|
||||
}
|
||||
|
||||
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
||||
if ttl <= 0 {
|
||||
ttl = 30 * time.Second
|
||||
}
|
||||
return &snapshotCache{
|
||||
ttl: ttl,
|
||||
items: make(map[string]snapshotCacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *snapshotCache) Get(key string) (snapshotCacheEntry, bool) {
|
||||
if c == nil || key == "" {
|
||||
return snapshotCacheEntry{}, false
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
c.mu.RLock()
|
||||
entry, ok := c.items[key]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
return snapshotCacheEntry{}, false
|
||||
}
|
||||
if now.After(entry.ExpiresAt) {
|
||||
c.mu.Lock()
|
||||
delete(c.items, key)
|
||||
c.mu.Unlock()
|
||||
return snapshotCacheEntry{}, false
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
|
||||
if c == nil {
|
||||
return snapshotCacheEntry{}
|
||||
}
|
||||
entry := snapshotCacheEntry{
|
||||
ETag: buildETagFromAny(payload),
|
||||
Payload: payload,
|
||||
ExpiresAt: time.Now().Add(c.ttl),
|
||||
}
|
||||
if key == "" {
|
||||
return entry
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.items[key] = entry
|
||||
c.mu.Unlock()
|
||||
return entry
|
||||
}
|
||||
|
||||
func buildETagFromAny(payload any) string {
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256(raw)
|
||||
return "\"" + hex.EncodeToString(sum[:]) + "\""
|
||||
}
|
||||
|
||||
func parseBoolQueryWithDefault(raw string, def bool) bool {
|
||||
value := strings.TrimSpace(strings.ToLower(raw))
|
||||
if value == "" {
|
||||
return def
|
||||
}
|
||||
switch value {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
default:
|
||||
return def
|
||||
}
|
||||
}
|
||||
128
backend/internal/handler/admin/snapshot_cache_test.go
Normal file
128
backend/internal/handler/admin/snapshot_cache_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSnapshotCache_SetAndGet(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
|
||||
entry := c.Set("key1", map[string]string{"hello": "world"})
|
||||
require.NotEmpty(t, entry.ETag)
|
||||
require.NotNil(t, entry.Payload)
|
||||
|
||||
got, ok := c.Get("key1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, entry.ETag, got.ETag)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_Expiration(t *testing.T) {
|
||||
c := newSnapshotCache(1 * time.Millisecond)
|
||||
|
||||
c.Set("key1", "value")
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
_, ok := c.Get("key1")
|
||||
require.False(t, ok, "expired entry should not be returned")
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetEmptyKey(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
_, ok := c.Get("")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetMiss(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
_, ok := c.Get("nonexistent")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_NilReceiver(t *testing.T) {
|
||||
var c *snapshotCache
|
||||
_, ok := c.Get("key")
|
||||
require.False(t, ok)
|
||||
|
||||
entry := c.Set("key", "value")
|
||||
require.Empty(t, entry.ETag)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_SetEmptyKey(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
|
||||
// Set with empty key should return entry but not store it
|
||||
entry := c.Set("", "value")
|
||||
require.NotEmpty(t, entry.ETag)
|
||||
|
||||
_, ok := c.Get("")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_DefaultTTL(t *testing.T) {
|
||||
c := newSnapshotCache(0)
|
||||
require.Equal(t, 30*time.Second, c.ttl)
|
||||
|
||||
c2 := newSnapshotCache(-1 * time.Second)
|
||||
require.Equal(t, 30*time.Second, c2.ttl)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_ETagDeterministic(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
payload := map[string]int{"a": 1, "b": 2}
|
||||
|
||||
entry1 := c.Set("k1", payload)
|
||||
entry2 := c.Set("k2", payload)
|
||||
require.Equal(t, entry1.ETag, entry2.ETag, "same payload should produce same ETag")
|
||||
}
|
||||
|
||||
func TestSnapshotCache_ETagFormat(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
entry := c.Set("k", "test")
|
||||
// ETag should be quoted hex string: "abcdef..."
|
||||
require.True(t, len(entry.ETag) > 2)
|
||||
require.Equal(t, byte('"'), entry.ETag[0])
|
||||
require.Equal(t, byte('"'), entry.ETag[len(entry.ETag)-1])
|
||||
}
|
||||
|
||||
func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
|
||||
// channels are not JSON-serializable
|
||||
etag := buildETagFromAny(make(chan int))
|
||||
require.Empty(t, etag)
|
||||
}
|
||||
|
||||
func TestParseBoolQueryWithDefault(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
def bool
|
||||
want bool
|
||||
}{
|
||||
{"empty returns default true", "", true, true},
|
||||
{"empty returns default false", "", false, false},
|
||||
{"1", "1", false, true},
|
||||
{"true", "true", false, true},
|
||||
{"TRUE", "TRUE", false, true},
|
||||
{"yes", "yes", false, true},
|
||||
{"on", "on", false, true},
|
||||
{"0", "0", true, false},
|
||||
{"false", "false", true, false},
|
||||
{"FALSE", "FALSE", true, false},
|
||||
{"no", "no", true, false},
|
||||
{"off", "off", true, false},
|
||||
{"whitespace trimmed", " true ", false, true},
|
||||
{"unknown returns default true", "maybe", true, true},
|
||||
{"unknown returns default false", "maybe", false, false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := parseBoolQueryWithDefault(tc.raw, tc.def)
|
||||
require.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -61,6 +61,15 @@ type CreateUsageCleanupTaskRequest struct {
|
||||
// GET /api/v1/admin/usage
|
||||
func (h *UsageHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
exactTotal := false
|
||||
if exactTotalRaw := strings.TrimSpace(c.Query("exact_total")); exactTotalRaw != "" {
|
||||
parsed, err := strconv.ParseBool(exactTotalRaw)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid exact_total value, use true or false")
|
||||
return
|
||||
}
|
||||
exactTotal = parsed
|
||||
}
|
||||
|
||||
// Parse filters
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
@@ -167,6 +176,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
ExactTotal: exactTotal,
|
||||
}
|
||||
|
||||
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||
|
||||
@@ -80,6 +80,29 @@ func TestAdminUsageListInvalidStream(t *testing.T) {
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageListExactTotalTrue(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=true", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.True(t, repo.listFilters.ExactTotal)
|
||||
}
|
||||
|
||||
func TestAdminUsageListInvalidExactTotal(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=oops", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -67,6 +69,8 @@ type BatchUserAttributesResponse struct {
|
||||
Attributes map[int64]map[int64]string `json:"attributes"`
|
||||
}
|
||||
|
||||
var userAttributesBatchCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
// AttributeDefinitionResponse represents attribute definition response
|
||||
type AttributeDefinitionResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
@@ -327,16 +331,32 @@ func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
userIDs := normalizeInt64IDList(req.UserIDs)
|
||||
if len(userIDs) == 0 {
|
||||
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
|
||||
return
|
||||
}
|
||||
|
||||
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs)
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
UserIDs []int64 `json:"user_ids"`
|
||||
}{
|
||||
UserIDs: userIDs,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := userAttributesBatchCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), userIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, BatchUserAttributesResponse{Attributes: attrs})
|
||||
payload := BatchUserAttributesResponse{Attributes: attrs}
|
||||
userAttributesBatchCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
@@ -91,6 +91,10 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
Search: search,
|
||||
Attributes: parseAttributeFilters(c),
|
||||
}
|
||||
if raw, ok := c.GetQuery("include_subscriptions"); ok {
|
||||
includeSubscriptions := parseBoolQueryWithDefault(raw, true)
|
||||
filters.IncludeSubscriptions = &includeSubscriptions
|
||||
}
|
||||
|
||||
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -36,6 +37,11 @@ type CreateAPIKeyRequest struct {
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||
Quota *float64 `json:"quota"` // 配额限制 (USD)
|
||||
ExpiresInDays *int `json:"expires_in_days"` // 过期天数
|
||||
|
||||
// Rate limit fields (0 = unlimited)
|
||||
RateLimit5h *float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d *float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d *float64 `json:"rate_limit_7d"`
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest represents the update API key request payload
|
||||
@@ -48,6 +54,12 @@ type UpdateAPIKeyRequest struct {
|
||||
Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制
|
||||
ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601)
|
||||
ResetQuota *bool `json:"reset_quota"` // 重置已用配额
|
||||
|
||||
// Rate limit fields (nil = no change, 0 = unlimited)
|
||||
RateLimit5h *float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d *float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d *float64 `json:"rate_limit_7d"`
|
||||
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // 重置限速用量
|
||||
}
|
||||
|
||||
// List handles listing user's API keys with pagination
|
||||
@@ -62,7 +74,23 @@ func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
|
||||
// Parse filter parameters
|
||||
var filters service.APIKeyListFilters
|
||||
if search := strings.TrimSpace(c.Query("search")); search != "" {
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
filters.Search = search
|
||||
}
|
||||
filters.Status = c.Query("status")
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
gid, err := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if err == nil {
|
||||
filters.GroupID = &gid
|
||||
}
|
||||
}
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params, filters)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -131,6 +159,15 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
if req.Quota != nil {
|
||||
svcReq.Quota = *req.Quota
|
||||
}
|
||||
if req.RateLimit5h != nil {
|
||||
svcReq.RateLimit5h = *req.RateLimit5h
|
||||
}
|
||||
if req.RateLimit1d != nil {
|
||||
svcReq.RateLimit1d = *req.RateLimit1d
|
||||
}
|
||||
if req.RateLimit7d != nil {
|
||||
svcReq.RateLimit7d = *req.RateLimit7d
|
||||
}
|
||||
|
||||
executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq)
|
||||
@@ -163,10 +200,14 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
svcReq := service.UpdateAPIKeyRequest{
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
Quota: req.Quota,
|
||||
ResetQuota: req.ResetQuota,
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
Quota: req.Quota,
|
||||
ResetQuota: req.ResetQuota,
|
||||
RateLimit5h: req.RateLimit5h,
|
||||
RateLimit1d: req.RateLimit1d,
|
||||
RateLimit7d: req.RateLimit7d,
|
||||
ResetRateLimitUsage: req.ResetRateLimitUsage,
|
||||
}
|
||||
if req.Name != "" {
|
||||
svcReq.Name = &req.Name
|
||||
|
||||
@@ -72,22 +72,31 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
return nil
|
||||
}
|
||||
return &APIKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
IPWhitelist: k.IPWhitelist,
|
||||
IPBlacklist: k.IPBlacklist,
|
||||
LastUsedAt: k.LastUsedAt,
|
||||
Quota: k.Quota,
|
||||
QuotaUsed: k.QuotaUsed,
|
||||
ExpiresAt: k.ExpiresAt,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
IPWhitelist: k.IPWhitelist,
|
||||
IPBlacklist: k.IPBlacklist,
|
||||
LastUsedAt: k.LastUsedAt,
|
||||
Quota: k.Quota,
|
||||
QuotaUsed: k.QuotaUsed,
|
||||
ExpiresAt: k.ExpiresAt,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
RateLimit5h: k.RateLimit5h,
|
||||
RateLimit1d: k.RateLimit1d,
|
||||
RateLimit7d: k.RateLimit7d,
|
||||
Usage5h: k.Usage5h,
|
||||
Usage1d: k.Usage1d,
|
||||
Usage7d: k.Usage7d,
|
||||
Window5hStart: k.Window5hStart,
|
||||
Window1dStart: k.Window1dStart,
|
||||
Window7dStart: k.Window7dStart,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,13 +17,14 @@ type CustomMenuItem struct {
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
@@ -77,6 +78,9 @@ type SystemSettings struct {
|
||||
OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"`
|
||||
|
||||
MinClaudeCodeVersion string `json:"min_claude_code_version"`
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -85,28 +89,29 @@ type DefaultSubscriptionSetting struct {
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
Version string `json:"version"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
|
||||
|
||||
@@ -47,6 +47,17 @@ type APIKey struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// Rate limit fields
|
||||
RateLimit5h float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d float64 `json:"rate_limit_7d"`
|
||||
Usage5h float64 `json:"usage_5h"`
|
||||
Usage1d float64 `json:"usage_1d"`
|
||||
Usage7d float64 `json:"usage_7d"`
|
||||
Window5hStart *time.Time `json:"window_5h_start"`
|
||||
Window1dStart *time.Time `json:"window_1d_start"`
|
||||
Window7dStart *time.Time `json:"window_7d_start"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -844,6 +845,10 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service
|
||||
|
||||
// Usage handles getting account balance and usage statistics for CC Switch integration
|
||||
// GET /v1/usage
|
||||
//
|
||||
// Two modes:
|
||||
// - quota_limited: API Key has quota or rate limits configured. Returns key-level limits/usage.
|
||||
// - unrestricted: No key-level limits. Returns subscription or wallet balance info.
|
||||
func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
@@ -857,54 +862,183 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 解析可选的日期范围参数(用于 model_stats 查询)
|
||||
startTime, endTime := h.parseUsageDateRange(c)
|
||||
|
||||
// Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应
|
||||
var usageData gin.H
|
||||
usageData := h.buildUsageData(ctx, apiKey.ID)
|
||||
|
||||
// Best-effort: 获取模型统计
|
||||
var modelStats any
|
||||
if h.usageService != nil {
|
||||
dashStats, err := h.usageService.GetAPIKeyDashboardStats(c.Request.Context(), apiKey.ID)
|
||||
if err == nil && dashStats != nil {
|
||||
usageData = gin.H{
|
||||
"today": gin.H{
|
||||
"requests": dashStats.TodayRequests,
|
||||
"input_tokens": dashStats.TodayInputTokens,
|
||||
"output_tokens": dashStats.TodayOutputTokens,
|
||||
"cache_creation_tokens": dashStats.TodayCacheCreationTokens,
|
||||
"cache_read_tokens": dashStats.TodayCacheReadTokens,
|
||||
"total_tokens": dashStats.TodayTokens,
|
||||
"cost": dashStats.TodayCost,
|
||||
"actual_cost": dashStats.TodayActualCost,
|
||||
},
|
||||
"total": gin.H{
|
||||
"requests": dashStats.TotalRequests,
|
||||
"input_tokens": dashStats.TotalInputTokens,
|
||||
"output_tokens": dashStats.TotalOutputTokens,
|
||||
"cache_creation_tokens": dashStats.TotalCacheCreationTokens,
|
||||
"cache_read_tokens": dashStats.TotalCacheReadTokens,
|
||||
"total_tokens": dashStats.TotalTokens,
|
||||
"cost": dashStats.TotalCost,
|
||||
"actual_cost": dashStats.TotalActualCost,
|
||||
},
|
||||
"average_duration_ms": dashStats.AverageDurationMs,
|
||||
"rpm": dashStats.Rpm,
|
||||
"tpm": dashStats.Tpm,
|
||||
if stats, err := h.usageService.GetAPIKeyModelStats(ctx, apiKey.ID, startTime, endTime); err == nil && len(stats) > 0 {
|
||||
modelStats = stats
|
||||
}
|
||||
}
|
||||
|
||||
// 判断模式: key 有总额度或速率限制 → quota_limited,否则 → unrestricted
|
||||
isQuotaLimited := apiKey.Quota > 0 || apiKey.HasRateLimits()
|
||||
|
||||
if isQuotaLimited {
|
||||
h.usageQuotaLimited(c, ctx, apiKey, usageData, modelStats)
|
||||
return
|
||||
}
|
||||
|
||||
h.usageUnrestricted(c, ctx, apiKey, subject, usageData, modelStats)
|
||||
}
|
||||
|
||||
// parseUsageDateRange 解析 start_date / end_date query params,默认返回近 30 天范围
|
||||
func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Time) {
|
||||
now := timezone.Now()
|
||||
endTime := now
|
||||
startTime := now.AddDate(0, 0, -30)
|
||||
|
||||
if s := c.Query("start_date"); s != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
|
||||
startTime = t
|
||||
}
|
||||
}
|
||||
if s := c.Query("end_date"); s != "" {
|
||||
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
|
||||
endTime = t.Add(24*time.Hour - time.Second) // end of day
|
||||
}
|
||||
}
|
||||
return startTime, endTime
|
||||
}
|
||||
|
||||
// buildUsageData 构建 today/total 用量摘要
|
||||
func (h *GatewayHandler) buildUsageData(ctx context.Context, apiKeyID int64) gin.H {
|
||||
if h.usageService == nil {
|
||||
return nil
|
||||
}
|
||||
dashStats, err := h.usageService.GetAPIKeyDashboardStats(ctx, apiKeyID)
|
||||
if err != nil || dashStats == nil {
|
||||
return nil
|
||||
}
|
||||
return gin.H{
|
||||
"today": gin.H{
|
||||
"requests": dashStats.TodayRequests,
|
||||
"input_tokens": dashStats.TodayInputTokens,
|
||||
"output_tokens": dashStats.TodayOutputTokens,
|
||||
"cache_creation_tokens": dashStats.TodayCacheCreationTokens,
|
||||
"cache_read_tokens": dashStats.TodayCacheReadTokens,
|
||||
"total_tokens": dashStats.TodayTokens,
|
||||
"cost": dashStats.TodayCost,
|
||||
"actual_cost": dashStats.TodayActualCost,
|
||||
},
|
||||
"total": gin.H{
|
||||
"requests": dashStats.TotalRequests,
|
||||
"input_tokens": dashStats.TotalInputTokens,
|
||||
"output_tokens": dashStats.TotalOutputTokens,
|
||||
"cache_creation_tokens": dashStats.TotalCacheCreationTokens,
|
||||
"cache_read_tokens": dashStats.TotalCacheReadTokens,
|
||||
"total_tokens": dashStats.TotalTokens,
|
||||
"cost": dashStats.TotalCost,
|
||||
"actual_cost": dashStats.TotalActualCost,
|
||||
},
|
||||
"average_duration_ms": dashStats.AverageDurationMs,
|
||||
"rpm": dashStats.Rpm,
|
||||
"tpm": dashStats.Tpm,
|
||||
}
|
||||
}
|
||||
|
||||
// usageQuotaLimited 处理 quota_limited 模式的响应
|
||||
func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, modelStats any) {
|
||||
resp := gin.H{
|
||||
"mode": "quota_limited",
|
||||
"isValid": apiKey.Status == service.StatusAPIKeyActive || apiKey.Status == service.StatusAPIKeyQuotaExhausted || apiKey.Status == service.StatusAPIKeyExpired,
|
||||
"status": apiKey.Status,
|
||||
}
|
||||
|
||||
// 总额度信息
|
||||
if apiKey.Quota > 0 {
|
||||
remaining := apiKey.GetQuotaRemaining()
|
||||
resp["quota"] = gin.H{
|
||||
"limit": apiKey.Quota,
|
||||
"used": apiKey.QuotaUsed,
|
||||
"remaining": remaining,
|
||||
"unit": "USD",
|
||||
}
|
||||
resp["remaining"] = remaining
|
||||
resp["unit"] = "USD"
|
||||
}
|
||||
|
||||
// 速率限制信息(从 DB 获取实时用量)
|
||||
if apiKey.HasRateLimits() && h.apiKeyService != nil {
|
||||
rateLimitData, err := h.apiKeyService.GetRateLimitData(ctx, apiKey.ID)
|
||||
if err == nil && rateLimitData != nil {
|
||||
var rateLimits []gin.H
|
||||
if apiKey.RateLimit5h > 0 {
|
||||
used := rateLimitData.Usage5h
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
"window": "5h",
|
||||
"limit": apiKey.RateLimit5h,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit5h-used),
|
||||
"window_start": rateLimitData.Window5hStart,
|
||||
})
|
||||
}
|
||||
if apiKey.RateLimit1d > 0 {
|
||||
used := rateLimitData.Usage1d
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
"window": "1d",
|
||||
"limit": apiKey.RateLimit1d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit1d-used),
|
||||
"window_start": rateLimitData.Window1dStart,
|
||||
})
|
||||
}
|
||||
if apiKey.RateLimit7d > 0 {
|
||||
used := rateLimitData.Usage7d
|
||||
rateLimits = append(rateLimits, gin.H{
|
||||
"window": "7d",
|
||||
"limit": apiKey.RateLimit7d,
|
||||
"used": used,
|
||||
"remaining": max(0, apiKey.RateLimit7d-used),
|
||||
"window_start": rateLimitData.Window7dStart,
|
||||
})
|
||||
}
|
||||
if len(rateLimits) > 0 {
|
||||
resp["rate_limits"] = rateLimits
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 订阅模式:返回订阅限额信息 + 用量统计
|
||||
// 过期时间
|
||||
if apiKey.ExpiresAt != nil {
|
||||
resp["expires_at"] = apiKey.ExpiresAt
|
||||
resp["days_until_expiry"] = apiKey.GetDaysUntilExpiry()
|
||||
}
|
||||
|
||||
if usageData != nil {
|
||||
resp["usage"] = usageData
|
||||
}
|
||||
if modelStats != nil {
|
||||
resp["model_stats"] = modelStats
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// usageUnrestricted 处理 unrestricted 模式的响应(向后兼容)
|
||||
func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, modelStats any) {
|
||||
// 订阅模式
|
||||
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
|
||||
subscription, ok := middleware2.GetSubscriptionFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription")
|
||||
return
|
||||
resp := gin.H{
|
||||
"mode": "unrestricted",
|
||||
"isValid": true,
|
||||
"planName": apiKey.Group.Name,
|
||||
"unit": "USD",
|
||||
}
|
||||
|
||||
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
|
||||
resp := gin.H{
|
||||
"isValid": true,
|
||||
"planName": apiKey.Group.Name,
|
||||
"remaining": remaining,
|
||||
"unit": "USD",
|
||||
"subscription": gin.H{
|
||||
// 订阅信息可能不在 context 中(/v1/usage 路径跳过了中间件的计费检查)
|
||||
subscription, ok := middleware2.GetSubscriptionFromContext(c)
|
||||
if ok {
|
||||
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
|
||||
resp["remaining"] = remaining
|
||||
resp["subscription"] = gin.H{
|
||||
"daily_usage_usd": subscription.DailyUsageUSD,
|
||||
"weekly_usage_usd": subscription.WeeklyUsageUSD,
|
||||
"monthly_usage_usd": subscription.MonthlyUsageUSD,
|
||||
@@ -912,23 +1046,28 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
"weekly_limit_usd": apiKey.Group.WeeklyLimitUSD,
|
||||
"monthly_limit_usd": apiKey.Group.MonthlyLimitUSD,
|
||||
"expires_at": subscription.ExpiresAt,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if usageData != nil {
|
||||
resp["usage"] = usageData
|
||||
}
|
||||
if modelStats != nil {
|
||||
resp["model_stats"] = modelStats
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
// 余额模式:返回钱包余额 + 用量统计
|
||||
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||
// 余额模式
|
||||
latestUser, err := h.userService.GetByID(ctx, subject.UserID)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
|
||||
return
|
||||
}
|
||||
|
||||
resp := gin.H{
|
||||
"mode": "unrestricted",
|
||||
"isValid": true,
|
||||
"planName": "钱包余额",
|
||||
"remaining": latestUser.Balance,
|
||||
@@ -938,6 +1077,9 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
if usageData != nil {
|
||||
resp["usage"] = usageData
|
||||
}
|
||||
if modelStats != nil {
|
||||
resp["model_stats"] = modelStats
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
@@ -1445,6 +1587,18 @@ func billingErrorDetails(err error) (status int, code, message string) {
|
||||
}
|
||||
return http.StatusServiceUnavailable, "billing_service_error", msg
|
||||
}
|
||||
if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) {
|
||||
msg := pkgerrors.Message(err)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
|
||||
}
|
||||
if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) {
|
||||
msg := pkgerrors.Message(err)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
|
||||
}
|
||||
if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) {
|
||||
msg := pkgerrors.Message(err)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
|
||||
}
|
||||
msg := pkgerrors.Message(err)
|
||||
if msg == "" {
|
||||
logger.L().With(
|
||||
|
||||
@@ -159,7 +159,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
|
||||
|
||||
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
|
||||
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
|
||||
|
||||
@@ -27,6 +27,7 @@ type AdminHandlers struct {
|
||||
UserAttribute *admin.UserAttributeHandler
|
||||
ErrorPassthrough *admin.ErrorPassthroughHandler
|
||||
APIKey *admin.AdminAPIKeyHandler
|
||||
ScheduledTest *admin.ScheduledTestHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
|
||||
192
backend/internal/handler/openai_gateway_compact_log_test.go
Normal file
192
backend/internal/handler/openai_gateway_compact_log_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var handlerStructuredLogCaptureMu sync.Mutex
|
||||
|
||||
type handlerInMemoryLogSink struct {
|
||||
mu sync.Mutex
|
||||
events []*logger.LogEvent
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) WriteLogEvent(event *logger.LogEvent) {
|
||||
if event == nil {
|
||||
return
|
||||
}
|
||||
cloned := *event
|
||||
if event.Fields != nil {
|
||||
cloned.Fields = make(map[string]any, len(event.Fields))
|
||||
for k, v := range event.Fields {
|
||||
cloned.Fields[k] = v
|
||||
}
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.events = append(s.events, &cloned)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
wantLevel := strings.ToLower(strings.TrimSpace(level))
|
||||
for _, ev := range s.events {
|
||||
if ev == nil {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *handlerInMemoryLogSink) ContainsFieldValue(field, substr string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, ev := range s.events {
|
||||
if ev == nil || ev.Fields == nil {
|
||||
continue
|
||||
}
|
||||
if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func captureHandlerStructuredLog(t *testing.T) (*handlerInMemoryLogSink, func()) {
|
||||
t.Helper()
|
||||
handlerStructuredLogCaptureMu.Lock()
|
||||
|
||||
err := logger.Init(logger.InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: logger.SamplingOptions{Enabled: false},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
sink := &handlerInMemoryLogSink{}
|
||||
logger.SetSink(sink)
|
||||
return sink, func() {
|
||||
logger.SetSink(nil)
|
||||
handlerStructuredLogCaptureMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOpenAIRemoteCompactPath(t *testing.T) {
|
||||
require.False(t, isOpenAIRemoteCompactPath(nil))
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
|
||||
require.True(t, isOpenAIRemoteCompactPath(c))
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact/", nil)
|
||||
require.True(t, isOpenAIRemoteCompactPath(c))
|
||||
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
require.False(t, isOpenAIRemoteCompactPath(c))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
c.Set(opsModelKey, "gpt-5.3-codex")
|
||||
c.Set(opsAccountIDKey, int64(123))
|
||||
c.Header("x-request-id", "rid-compact-ok")
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now().Add(-8*time.Millisecond))
|
||||
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
|
||||
require.True(t, logSink.ContainsFieldValue("compact_outcome", "succeeded"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "200"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.3-codex"))
|
||||
require.True(t, logSink.ContainsFieldValue("account_id", "123"))
|
||||
require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-compact-ok"))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
c.Status(http.StatusBadGateway)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now())
|
||||
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
require.True(t, logSink.ContainsFieldValue("compact_outcome", "failed"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "502"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/responses/compact"))
|
||||
}
|
||||
|
||||
func TestLogOpenAIRemoteCompactOutcome_NonCompactSkips(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.logOpenAIRemoteCompactOutcome(c, time.Now())
|
||||
|
||||
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.succeeded", "info"))
|
||||
require.False(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureHandlerStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
require.True(t, logSink.ContainsMessageAtLevel("codex.remote_compact.failed", "warn"))
|
||||
require.True(t, logSink.ContainsFieldValue("status_code", "401"))
|
||||
require.True(t, logSink.ContainsFieldValue("path", "/v1/responses/compact"))
|
||||
}
|
||||
@@ -33,6 +33,7 @@ type OpenAIGatewayHandler struct {
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
@@ -61,6 +62,7 @@ func NewOpenAIGatewayHandler(
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,6 +72,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
compactStartedAt := time.Now()
|
||||
defer h.logOpenAIRemoteCompactOutcome(c, compactStartedAt)
|
||||
setOpenAIClientTransportHTTP(c)
|
||||
|
||||
requestStart := time.Now()
|
||||
@@ -340,6 +344,86 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func isOpenAIRemoteCompactPath(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||
return false
|
||||
}
|
||||
normalizedPath := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
|
||||
return strings.HasSuffix(normalizedPath, "/responses/compact")
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) logOpenAIRemoteCompactOutcome(c *gin.Context, startedAt time.Time) {
|
||||
if !isOpenAIRemoteCompactPath(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
path string
|
||||
status int
|
||||
)
|
||||
if c != nil {
|
||||
if c.Request != nil {
|
||||
ctx = c.Request.Context()
|
||||
if c.Request.URL != nil {
|
||||
path = strings.TrimSpace(c.Request.URL.Path)
|
||||
}
|
||||
}
|
||||
if c.Writer != nil {
|
||||
status = c.Writer.Status()
|
||||
}
|
||||
}
|
||||
|
||||
outcome := "failed"
|
||||
if status >= 200 && status < 300 {
|
||||
outcome = "succeeded"
|
||||
}
|
||||
latencyMs := time.Since(startedAt).Milliseconds()
|
||||
if latencyMs < 0 {
|
||||
latencyMs = 0
|
||||
}
|
||||
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
zap.Bool("remote_compact", true),
|
||||
zap.String("compact_outcome", outcome),
|
||||
zap.Int("status_code", status),
|
||||
zap.Int64("latency_ms", latencyMs),
|
||||
zap.String("path", path),
|
||||
zap.Bool("force_codex_cli", h != nil && h.cfg != nil && h.cfg.Gateway.ForceCodexCLI),
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
if userAgent := strings.TrimSpace(c.GetHeader("User-Agent")); userAgent != "" {
|
||||
fields = append(fields, zap.String("request_user_agent", userAgent))
|
||||
}
|
||||
if v, ok := c.Get(opsModelKey); ok {
|
||||
if model, ok := v.(string); ok && strings.TrimSpace(model) != "" {
|
||||
fields = append(fields, zap.String("request_model", strings.TrimSpace(model)))
|
||||
}
|
||||
}
|
||||
if v, ok := c.Get(opsAccountIDKey); ok {
|
||||
if accountID, ok := v.(int64); ok && accountID > 0 {
|
||||
fields = append(fields, zap.Int64("account_id", accountID))
|
||||
}
|
||||
}
|
||||
if c.Writer != nil {
|
||||
if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("x-request-id")); upstreamRequestID != "" {
|
||||
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
|
||||
} else if upstreamRequestID := strings.TrimSpace(c.Writer.Header().Get("X-Request-Id")); upstreamRequestID != "" {
|
||||
fields = append(fields, zap.String("upstream_request_id", upstreamRequestID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log := logger.FromContext(ctx).With(fields...)
|
||||
if outcome == "succeeded" {
|
||||
log.Info("codex.remote_compact.succeeded")
|
||||
return
|
||||
}
|
||||
log.Warn("codex.remote_compact.failed")
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
|
||||
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||
return true
|
||||
|
||||
@@ -32,27 +32,28 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.PublicSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
Version: h.version,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -996,7 +996,7 @@ func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*se
|
||||
}
|
||||
func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
|
||||
func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
|
||||
func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
||||
@@ -1032,6 +1032,15 @@ func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64
|
||||
func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAPIKeyRepoForHandler) IncrementRateLimitUsage(context.Context, int64, float64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAPIKeyRepoForHandler) ResetRateLimitWindows(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAPIKeyRepoForHandler) GetRateLimitData(context.Context, int64) (*service.APIKeyRateLimitData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// newTestAPIKeyService 创建测试用的 APIKeyService
|
||||
func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService {
|
||||
@@ -2089,6 +2098,12 @@ func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context,
|
||||
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
|
||||
return r.accounts, nil
|
||||
}
|
||||
func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) {
|
||||
return r.accounts, nil
|
||||
}
|
||||
func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) {
|
||||
return r.accounts, nil
|
||||
}
|
||||
func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -182,6 +182,12 @@ func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platfo
|
||||
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
||||
return r.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return r.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||
return r.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -405,7 +411,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
deferredService := service.NewDeferredService(accountRepo, nil, 0)
|
||||
billingService := service.NewBillingService(cfg, nil)
|
||||
concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{})
|
||||
billingCacheService := service.NewBillingCacheService(nil, nil, nil, cfg)
|
||||
billingCacheService := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
|
||||
t.Cleanup(func() {
|
||||
billingCacheService.Stop()
|
||||
})
|
||||
|
||||
@@ -30,6 +30,7 @@ func ProvideAdminHandlers(
|
||||
userAttributeHandler *admin.UserAttributeHandler,
|
||||
errorPassthroughHandler *admin.ErrorPassthroughHandler,
|
||||
apiKeyHandler *admin.AdminAPIKeyHandler,
|
||||
scheduledTestHandler *admin.ScheduledTestHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@@ -53,6 +54,7 @@ func ProvideAdminHandlers(
|
||||
UserAttribute: userAttributeHandler,
|
||||
ErrorPassthrough: errorPassthroughHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
ScheduledTest: scheduledTestHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,6 +143,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewUserAttributeHandler,
|
||||
admin.NewErrorPassthroughHandler,
|
||||
admin.NewAdminAPIKeyHandler,
|
||||
admin.NewScheduledTestHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
@@ -53,8 +53,7 @@ const (
|
||||
var defaultUserAgentVersion = "1.19.6"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
// 默认值使用占位符,生产环境请通过环境变量注入真实值。
|
||||
var defaultClientSecret = "GOCSPX-your-client-secret"
|
||||
var defaultClientSecret = "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"
|
||||
|
||||
func init() {
|
||||
// 从环境变量读取版本号,未设置则使用默认值
|
||||
|
||||
@@ -684,7 +684,7 @@ func TestConstants_值正确(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
|
||||
}
|
||||
if secret != "GOCSPX-your-client-secret" {
|
||||
if secret != "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" {
|
||||
t.Errorf("默认 client_secret 不匹配: got %s", secret)
|
||||
}
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
|
||||
@@ -39,7 +39,7 @@ const (
|
||||
// They enable the "login without creating your own OAuth client" experience, but Google may
|
||||
// restrict which scopes are allowed for this client.
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
GeminiCLIOAuthClientSecret = "GOCSPX-your-client-secret"
|
||||
GeminiCLIOAuthClientSecret = "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"
|
||||
|
||||
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
|
||||
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
|
||||
|
||||
@@ -57,25 +57,28 @@ type DashboardStats struct {
|
||||
|
||||
// TrendDataPoint represents a single point in trend data
|
||||
type TrendDataPoint struct {
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheTokens int64 `json:"cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
Date string `json:"date"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheCreationTokens int64 `json:"cache_creation_tokens"`
|
||||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// ModelStat represents usage statistics for a single model
|
||||
type ModelStat struct {
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
Model string `json:"model"`
|
||||
Requests int64 `json:"requests"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheCreationTokens int64 `json:"cache_creation_tokens"`
|
||||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// GroupStat represents usage statistics for a single group
|
||||
@@ -154,6 +157,8 @@ type UsageLogFilters struct {
|
||||
BillingType *int8
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.
|
||||
ExactTotal bool
|
||||
}
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
|
||||
@@ -437,6 +437,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
switch status {
|
||||
case "rate_limited":
|
||||
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
|
||||
case "temp_unschedulable":
|
||||
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
|
||||
col := s.C("temp_unschedulable_until")
|
||||
s.Where(entsql.And(
|
||||
entsql.Not(entsql.IsNull(col)),
|
||||
entsql.GT(col, entsql.Expr("NOW()")),
|
||||
))
|
||||
}))
|
||||
default:
|
||||
q = q.Where(dbaccount.StatusEQ(status))
|
||||
}
|
||||
@@ -640,7 +648,17 @@ func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
SetStatus(service.StatusActive).
|
||||
SetErrorMessage("").
|
||||
Save(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 清除临时不可调度状态,重置 401 升级链
|
||||
_, _ = r.sql.ExecContext(ctx, `
|
||||
UPDATE accounts
|
||||
SET temp_unschedulable_until = NULL,
|
||||
temp_unschedulable_reason = NULL
|
||||
WHERE id = $1 AND deleted_at IS NULL
|
||||
`, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
|
||||
@@ -829,6 +847,51 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
now := time.Now()
|
||||
accounts, err := r.client.Account.Query().
|
||||
Where(
|
||||
dbaccount.PlatformEQ(platform),
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
dbaccount.Not(dbaccount.HasAccountGroups()),
|
||||
tempUnschedulablePredicate(),
|
||||
notExpiredPredicate(now),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
Order(dbent.Asc(dbaccount.FieldPriority)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||
if len(platforms) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
now := time.Now()
|
||||
accounts, err := r.client.Account.Query().
|
||||
Where(
|
||||
dbaccount.PlatformIn(platforms...),
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
dbaccount.Not(dbaccount.HasAccountGroups()),
|
||||
tempUnschedulablePredicate(),
|
||||
notExpiredPredicate(now),
|
||||
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
|
||||
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
|
||||
).
|
||||
Order(dbent.Asc(dbaccount.FieldPriority)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
||||
if len(platforms) == 0 {
|
||||
return nil, nil
|
||||
|
||||
@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
|
||||
|
||||
userRepo := newUserRepositoryWithSQL(entClient, tx)
|
||||
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
|
||||
apiKeyRepo := NewAPIKeyRepository(entClient)
|
||||
apiKeyRepo := newAPIKeyRepositoryWithSQL(entClient, tx)
|
||||
|
||||
u := &service.User{
|
||||
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -16,10 +17,15 @@ import (
|
||||
|
||||
type apiKeyRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
|
||||
return &apiKeyRepository{client: client}
|
||||
func NewAPIKeyRepository(client *dbent.Client, sqlDB *sql.DB) service.APIKeyRepository {
|
||||
return newAPIKeyRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
|
||||
func newAPIKeyRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *apiKeyRepository {
|
||||
return &apiKeyRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
|
||||
@@ -37,7 +43,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
|
||||
SetNillableLastUsedAt(key.LastUsedAt).
|
||||
SetQuota(key.Quota).
|
||||
SetQuotaUsed(key.QuotaUsed).
|
||||
SetNillableExpiresAt(key.ExpiresAt)
|
||||
SetNillableExpiresAt(key.ExpiresAt).
|
||||
SetRateLimit5h(key.RateLimit5h).
|
||||
SetRateLimit1d(key.RateLimit1d).
|
||||
SetRateLimit7d(key.RateLimit7d)
|
||||
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
@@ -118,6 +127,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
apikey.FieldQuota,
|
||||
apikey.FieldQuotaUsed,
|
||||
apikey.FieldExpiresAt,
|
||||
apikey.FieldRateLimit5h,
|
||||
apikey.FieldRateLimit1d,
|
||||
apikey.FieldRateLimit7d,
|
||||
).
|
||||
WithUser(func(q *dbent.UserQuery) {
|
||||
q.Select(
|
||||
@@ -179,6 +191,12 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
||||
SetStatus(key.Status).
|
||||
SetQuota(key.Quota).
|
||||
SetQuotaUsed(key.QuotaUsed).
|
||||
SetRateLimit5h(key.RateLimit5h).
|
||||
SetRateLimit1d(key.RateLimit1d).
|
||||
SetRateLimit7d(key.RateLimit7d).
|
||||
SetUsage5h(key.Usage5h).
|
||||
SetUsage1d(key.Usage1d).
|
||||
SetUsage7d(key.Usage7d).
|
||||
SetUpdatedAt(now)
|
||||
if key.GroupID != nil {
|
||||
builder.SetGroupID(*key.GroupID)
|
||||
@@ -193,6 +211,23 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
||||
builder.ClearExpiresAt()
|
||||
}
|
||||
|
||||
// Rate limit window start times
|
||||
if key.Window5hStart != nil {
|
||||
builder.SetWindow5hStart(*key.Window5hStart)
|
||||
} else {
|
||||
builder.ClearWindow5hStart()
|
||||
}
|
||||
if key.Window1dStart != nil {
|
||||
builder.SetWindow1dStart(*key.Window1dStart)
|
||||
} else {
|
||||
builder.ClearWindow1dStart()
|
||||
}
|
||||
if key.Window7dStart != nil {
|
||||
builder.SetWindow7dStart(*key.Window7dStart)
|
||||
} else {
|
||||
builder.ClearWindow7dStart()
|
||||
}
|
||||
|
||||
// IP 限制字段
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
@@ -246,9 +281,27 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
||||
|
||||
// Apply filters
|
||||
if filters.Search != "" {
|
||||
q = q.Where(apikey.Or(
|
||||
apikey.NameContainsFold(filters.Search),
|
||||
apikey.KeyContainsFold(filters.Search),
|
||||
))
|
||||
}
|
||||
if filters.Status != "" {
|
||||
q = q.Where(apikey.StatusEQ(filters.Status))
|
||||
}
|
||||
if filters.GroupID != nil {
|
||||
if *filters.GroupID == 0 {
|
||||
q = q.Where(apikey.GroupIDIsNil())
|
||||
} else {
|
||||
q = q.Where(apikey.GroupIDEQ(*filters.GroupID))
|
||||
}
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -412,25 +465,92 @@ func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
// IncrementRateLimitUsage atomically increments all rate limit usage counters and initializes
|
||||
// window start times via COALESCE if not already set.
|
||||
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = usage_5h + $1,
|
||||
usage_1d = usage_1d + $1,
|
||||
usage_7d = usage_7d + $1,
|
||||
window_5h_start = COALESCE(window_5h_start, NOW()),
|
||||
window_1d_start = COALESCE(window_1d_start, NOW()),
|
||||
window_7d_start = COALESCE(window_7d_start, NOW()),
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL`,
|
||||
cost, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// ResetRateLimitWindows resets expired rate limit windows atomically.
|
||||
func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRateLimitData returns the current rate limit usage and window start times for an API key.
|
||||
func (r *apiKeyRepository) GetRateLimitData(ctx context.Context, id int64) (result *service.APIKeyRateLimitData, err error) {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT usage_5h, usage_1d, usage_7d, window_5h_start, window_1d_start, window_7d_start
|
||||
FROM api_keys
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
if !rows.Next() {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
data := &service.APIKeyRateLimitData{}
|
||||
if err := rows.Scan(&data.Usage5h, &data.Usage1d, &data.Usage7d, &data.Window5hStart, &data.Window1dStart, &data.Window7dStart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, rows.Err()
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.APIKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
Status: m.Status,
|
||||
IPWhitelist: m.IPWhitelist,
|
||||
IPBlacklist: m.IPBlacklist,
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
Quota: m.Quota,
|
||||
QuotaUsed: m.QuotaUsed,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
Status: m.Status,
|
||||
IPWhitelist: m.IPWhitelist,
|
||||
IPBlacklist: m.IPBlacklist,
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
Quota: m.Quota,
|
||||
QuotaUsed: m.QuotaUsed,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
RateLimit5h: m.RateLimit5h,
|
||||
RateLimit1d: m.RateLimit1d,
|
||||
RateLimit7d: m.RateLimit7d,
|
||||
Usage5h: m.Usage5h,
|
||||
Usage1d: m.Usage1d,
|
||||
Usage7d: m.Usage7d,
|
||||
Window5hStart: m.Window5hStart,
|
||||
Window1dStart: m.Window1dStart,
|
||||
Window7dStart: m.Window7dStart,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
|
||||
@@ -26,7 +26,7 @@ func (s *APIKeyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
|
||||
s.repo = newAPIKeyRepositoryWithSQL(s.client, tx)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepoSuite(t *testing.T) {
|
||||
@@ -158,7 +158,7 @@ func (s *APIKeyRepoSuite) TestListByUserID() {
|
||||
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
@@ -170,7 +170,7 @@ func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
|
||||
}
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}, service.APIKeyListFilters{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(5), page.Total)
|
||||
@@ -314,7 +314,7 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().Equal(service.StatusDisabled, got2.Status)
|
||||
s.Require().Nil(got2.GroupID)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(keys, 1)
|
||||
@@ -421,7 +421,7 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
||||
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := NewAPIKeyRepository(client).(*apiKeyRepository)
|
||||
repo := NewAPIKeyRepository(client, integrationDB).(*apiKeyRepository)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建测试用户和 API Key
|
||||
|
||||
@@ -14,10 +14,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
billingCacheJitter = 30 * time.Second
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingRateLimitKeyPrefix = "apikey:rate:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
billingCacheJitter = 30 * time.Second
|
||||
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
|
||||
)
|
||||
|
||||
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
||||
@@ -49,6 +51,20 @@ const (
|
||||
subFieldVersion = "version"
|
||||
)
|
||||
|
||||
// billingRateLimitKey generates the Redis key for API key rate limit cache.
|
||||
func billingRateLimitKey(keyID int64) string {
|
||||
return fmt.Sprintf("%s%d", billingRateLimitKeyPrefix, keyID)
|
||||
}
|
||||
|
||||
const (
|
||||
rateLimitFieldUsage5h = "usage_5h"
|
||||
rateLimitFieldUsage1d = "usage_1d"
|
||||
rateLimitFieldUsage7d = "usage_7d"
|
||||
rateLimitFieldWindow5h = "window_5h"
|
||||
rateLimitFieldWindow1d = "window_1d"
|
||||
rateLimitFieldWindow7d = "window_7d"
|
||||
)
|
||||
|
||||
var (
|
||||
deductBalanceScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
@@ -73,6 +89,21 @@ var (
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
|
||||
// updateRateLimitUsageScript atomically increments all three rate limit usage counters.
|
||||
// Returns 0 if the key doesn't exist (cache miss), 1 on success.
|
||||
updateRateLimitUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type billingCache struct {
|
||||
@@ -195,3 +226,69 @@ func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID,
|
||||
key := billingSubKey(userID, groupID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*service.APIKeyRateLimitCacheData, error) {
|
||||
key := billingRateLimitKey(keyID)
|
||||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
data := &service.APIKeyRateLimitCacheData{}
|
||||
if v, ok := result[rateLimitFieldUsage5h]; ok {
|
||||
data.Usage5h, _ = strconv.ParseFloat(v, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldUsage1d]; ok {
|
||||
data.Usage1d, _ = strconv.ParseFloat(v, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldUsage7d]; ok {
|
||||
data.Usage7d, _ = strconv.ParseFloat(v, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldWindow5h]; ok {
|
||||
data.Window5h, _ = strconv.ParseInt(v, 10, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldWindow1d]; ok {
|
||||
data.Window1d, _ = strconv.ParseInt(v, 10, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldWindow7d]; ok {
|
||||
data.Window7d, _ = strconv.ParseInt(v, 10, 64)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *service.APIKeyRateLimitCacheData) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
key := billingRateLimitKey(keyID)
|
||||
fields := map[string]any{
|
||||
rateLimitFieldUsage5h: data.Usage5h,
|
||||
rateLimitFieldUsage1d: data.Usage1d,
|
||||
rateLimitFieldUsage7d: data.Usage7d,
|
||||
rateLimitFieldWindow5h: data.Window5h,
|
||||
rateLimitFieldWindow1d: data.Window1d,
|
||||
rateLimitFieldWindow7d: data.Window7d,
|
||||
}
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, rateLimitCacheTTL)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||
key := billingRateLimitKey(keyID)
|
||||
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||||
key := billingRateLimitKey(keyID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
@@ -66,6 +66,13 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil
|
||||
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
|
||||
},
|
||||
},
|
||||
"061_add_usage_log_request_type.sql": {
|
||||
fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
|
||||
acceptedDBChecksum: map[string]struct{}{
|
||||
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {},
|
||||
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
|
||||
|
||||
@@ -25,6 +25,24 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
|
||||
require.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("061历史checksum可兼容", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"061_add_usage_log_request_type.sql",
|
||||
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0",
|
||||
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
|
||||
)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("061第二个历史checksum可兼容", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"061_add_usage_log_request_type.sql",
|
||||
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3",
|
||||
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
|
||||
)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("非白名单迁移不兼容", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"001_init.sql",
|
||||
|
||||
183
backend/internal/repository/scheduled_test_repo.go
Normal file
183
backend/internal/repository/scheduled_test_repo.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// --- Plan Repository ---
|
||||
|
||||
type scheduledTestPlanRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewScheduledTestPlanRepository(db *sql.DB) service.ScheduledTestPlanRepository {
|
||||
return &scheduledTestPlanRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) Create(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_plans (account_id, model_id, cron_expression, enabled, max_results, next_run_at, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.AccountID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) GetByID(ctx context.Context, id int64) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE id = $1
|
||||
`, id)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListByAccountID(ctx context.Context, accountID int64) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans WHERE account_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
return scanPlans(rows)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) ListDue(ctx context.Context, now time.Time) ([]*service.ScheduledTestPlan, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
FROM scheduled_test_plans
|
||||
WHERE enabled = true AND next_run_at <= $1
|
||||
ORDER BY next_run_at ASC
|
||||
`, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
return scanPlans(rows)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) Update(ctx context.Context, plan *service.ScheduledTestPlan) (*service.ScheduledTestPlan, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
UPDATE scheduled_test_plans
|
||||
SET model_id = $2, cron_expression = $3, enabled = $4, max_results = $5, next_run_at = $6, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, account_id, model_id, cron_expression, enabled, max_results, last_run_at, next_run_at, created_at, updated_at
|
||||
`, plan.ID, plan.ModelID, plan.CronExpression, plan.Enabled, plan.MaxResults, plan.NextRunAt)
|
||||
return scanPlan(row)
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.db.ExecContext(ctx, `DELETE FROM scheduled_test_plans WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *scheduledTestPlanRepository) UpdateAfterRun(ctx context.Context, id int64, lastRunAt time.Time, nextRunAt time.Time) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
UPDATE scheduled_test_plans SET last_run_at = $2, next_run_at = $3, updated_at = NOW() WHERE id = $1
|
||||
`, id, lastRunAt, nextRunAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// --- Result Repository ---
|
||||
|
||||
type scheduledTestResultRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewScheduledTestResultRepository(db *sql.DB) service.ScheduledTestResultRepository {
|
||||
return &scheduledTestResultRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *scheduledTestResultRepository) Create(ctx context.Context, result *service.ScheduledTestResult) (*service.ScheduledTestResult, error) {
|
||||
row := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO scheduled_test_results (plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
|
||||
RETURNING id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
|
||||
`, result.PlanID, result.Status, result.ResponseText, result.ErrorMessage, result.LatencyMs, result.StartedAt, result.FinishedAt)
|
||||
|
||||
out := &service.ScheduledTestResult{}
|
||||
if err := row.Scan(
|
||||
&out.ID, &out.PlanID, &out.Status, &out.ResponseText, &out.ErrorMessage,
|
||||
&out.LatencyMs, &out.StartedAt, &out.FinishedAt, &out.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *scheduledTestResultRepository) ListByPlanID(ctx context.Context, planID int64, limit int) ([]*service.ScheduledTestResult, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, plan_id, status, response_text, error_message, latency_ms, started_at, finished_at, created_at
|
||||
FROM scheduled_test_results
|
||||
WHERE plan_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2
|
||||
`, planID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var results []*service.ScheduledTestResult
|
||||
for rows.Next() {
|
||||
r := &service.ScheduledTestResult{}
|
||||
if err := rows.Scan(
|
||||
&r.ID, &r.PlanID, &r.Status, &r.ResponseText, &r.ErrorMessage,
|
||||
&r.LatencyMs, &r.StartedAt, &r.FinishedAt, &r.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, r)
|
||||
}
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
func (r *scheduledTestResultRepository) PruneOldResults(ctx context.Context, planID int64, keepCount int) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
DELETE FROM scheduled_test_results
|
||||
WHERE id IN (
|
||||
SELECT id FROM (
|
||||
SELECT id, ROW_NUMBER() OVER (PARTITION BY plan_id ORDER BY created_at DESC) AS rn
|
||||
FROM scheduled_test_results
|
||||
WHERE plan_id = $1
|
||||
) ranked
|
||||
WHERE rn > $2
|
||||
)
|
||||
`, planID, keepCount)
|
||||
return err
|
||||
}
|
||||
|
||||
// --- scan helpers ---
|
||||
|
||||
type scannable interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanPlan(row scannable) (*service.ScheduledTestPlan, error) {
|
||||
p := &service.ScheduledTestPlan{}
|
||||
if err := row.Scan(
|
||||
&p.ID, &p.AccountID, &p.ModelID, &p.CronExpression, &p.Enabled, &p.MaxResults,
|
||||
&p.LastRunAt, &p.NextRunAt, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func scanPlans(rows *sql.Rows) ([]*service.ScheduledTestPlan, error) {
|
||||
var plans []*service.ScheduledTestPlan
|
||||
for rows.Next() {
|
||||
p, err := scanPlan(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plans = append(plans, p)
|
||||
}
|
||||
return plans, rows.Err()
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func (s *SettingRepoSuite) TestSet_EmptyValue() {
|
||||
func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
|
||||
// 模拟保存站点设置,部分字段有值,部分字段为空
|
||||
settings := map[string]string{
|
||||
"site_name": "AICodex2API",
|
||||
"site_name": "Sub2api",
|
||||
"site_subtitle": "Subscription to API",
|
||||
"site_logo": "", // 用户未上传Logo
|
||||
"api_base_url": "", // 用户未设置API地址
|
||||
@@ -136,7 +136,7 @@ func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
|
||||
result, err := s.repo.GetMultiple(s.ctx, []string{"site_name", "site_subtitle", "site_logo", "api_base_url", "contact_info", "doc_url"})
|
||||
s.Require().NoError(err, "GetMultiple after SetMultiple with empty values")
|
||||
|
||||
s.Require().Equal("AICodex2API", result["site_name"])
|
||||
s.Require().Equal("Sub2api", result["site_name"])
|
||||
s.Require().Equal("Subscription to API", result["site_subtitle"])
|
||||
s.Require().Equal("", result["site_logo"], "empty site_logo should be preserved")
|
||||
s.Require().Equal("", result["api_base_url"], "empty api_base_url should be preserved")
|
||||
|
||||
@@ -41,7 +41,7 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
|
||||
|
||||
repo := NewAPIKeyRepository(client)
|
||||
repo := NewAPIKeyRepository(client, integrationDB)
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
|
||||
@@ -73,7 +73,7 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
|
||||
|
||||
repo := NewAPIKeyRepository(client)
|
||||
repo := NewAPIKeyRepository(client, integrationDB)
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
|
||||
@@ -93,7 +93,7 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
|
||||
|
||||
repo := NewAPIKeyRepository(client)
|
||||
repo := NewAPIKeyRepository(client, integrationDB)
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
|
||||
|
||||
@@ -1363,7 +1363,8 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
@@ -1401,6 +1402,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
@@ -1473,7 +1476,16 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
}
|
||||
|
||||
whereClause := buildWhere(conditions)
|
||||
logs, page, err := r.listUsageLogsWithPagination(ctx, whereClause, args, params)
|
||||
var (
|
||||
logs []service.UsageLog
|
||||
page *pagination.PaginationResult
|
||||
err error
|
||||
)
|
||||
if shouldUseFastUsageLogTotal(filters) {
|
||||
logs, page, err = r.listUsageLogsWithFastPagination(ctx, whereClause, args, params)
|
||||
} else {
|
||||
logs, page, err = r.listUsageLogsWithPagination(ctx, whereClause, args, params)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -1484,17 +1496,45 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
return logs, page, nil
|
||||
}
|
||||
|
||||
func shouldUseFastUsageLogTotal(filters UsageLogFilters) bool {
|
||||
if filters.ExactTotal {
|
||||
return false
|
||||
}
|
||||
// 强选择过滤下记录集通常较小,保留精确总数。
|
||||
return filters.UserID == 0 && filters.APIKeyID == 0 && filters.AccountID == 0
|
||||
}
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
type UsageStats = usagestats.UsageStats
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
||||
|
||||
func normalizePositiveInt64IDs(ids []int64) []int64 {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[int64]struct{}, len(ids))
|
||||
out := make([]int64, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
|
||||
// If startTime is zero, defaults to 30 days ago.
|
||||
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
|
||||
result := make(map[int64]*BatchUserUsageStats)
|
||||
if len(userIDs) == 0 {
|
||||
normalizedUserIDs := normalizePositiveInt64IDs(userIDs)
|
||||
if len(normalizedUserIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -1506,58 +1546,36 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
endTime = time.Now()
|
||||
}
|
||||
|
||||
for _, id := range userIDs {
|
||||
for _, id := range normalizedUserIDs {
|
||||
result[id] = &BatchUserUsageStats{UserID: id}
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||||
SELECT
|
||||
user_id,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
|
||||
FROM usage_logs
|
||||
WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||
WHERE user_id = ANY($1)
|
||||
AND created_at >= LEAST($2, $4)
|
||||
GROUP BY user_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime)
|
||||
today := timezone.Today()
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var userID int64
|
||||
var total float64
|
||||
if err := rows.Scan(&userID, &total); err != nil {
|
||||
var todayTotal float64
|
||||
if err := rows.Scan(&userID, &total, &todayTotal); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
if stats, ok := result[userID]; ok {
|
||||
stats.TotalActualCost = total
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
today := timezone.Today()
|
||||
todayQuery := `
|
||||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as today_cost
|
||||
FROM usage_logs
|
||||
WHERE user_id = ANY($1) AND created_at >= $2
|
||||
GROUP BY user_id
|
||||
`
|
||||
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(userIDs), today)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var userID int64
|
||||
var total float64
|
||||
if err := rows.Scan(&userID, &total); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
if stats, ok := result[userID]; ok {
|
||||
stats.TodayActualCost = total
|
||||
stats.TodayActualCost = todayTotal
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
@@ -1577,7 +1595,8 @@ type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
||||
// If startTime is zero, defaults to 30 days ago.
|
||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||
result := make(map[int64]*BatchAPIKeyUsageStats)
|
||||
if len(apiKeyIDs) == 0 {
|
||||
normalizedAPIKeyIDs := normalizePositiveInt64IDs(apiKeyIDs)
|
||||
if len(normalizedAPIKeyIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -1589,58 +1608,36 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
||||
endTime = time.Now()
|
||||
}
|
||||
|
||||
for _, id := range apiKeyIDs {
|
||||
for _, id := range normalizedAPIKeyIDs {
|
||||
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||||
SELECT
|
||||
api_key_id,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
|
||||
FROM usage_logs
|
||||
WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||
WHERE api_key_id = ANY($1)
|
||||
AND created_at >= LEAST($2, $4)
|
||||
GROUP BY api_key_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime)
|
||||
today := timezone.Today()
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedAPIKeyIDs), startTime, endTime, today)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var apiKeyID int64
|
||||
var total float64
|
||||
if err := rows.Scan(&apiKeyID, &total); err != nil {
|
||||
var todayTotal float64
|
||||
if err := rows.Scan(&apiKeyID, &total, &todayTotal); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
if stats, ok := result[apiKeyID]; ok {
|
||||
stats.TotalActualCost = total
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
today := timezone.Today()
|
||||
todayQuery := `
|
||||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost
|
||||
FROM usage_logs
|
||||
WHERE api_key_id = ANY($1) AND created_at >= $2
|
||||
GROUP BY api_key_id
|
||||
`
|
||||
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(apiKeyIDs), today)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var apiKeyID int64
|
||||
var total float64
|
||||
if err := rows.Scan(&apiKeyID, &total); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
if stats, ok := result[apiKeyID]; ok {
|
||||
stats.TodayActualCost = total
|
||||
stats.TodayActualCost = todayTotal
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
@@ -1655,6 +1652,13 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
||||
|
||||
// GetUsageTrendWithFilters returns usage trend data with optional filters
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
|
||||
if shouldUsePreaggregatedTrend(granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) {
|
||||
aggregated, aggregatedErr := r.getUsageTrendFromAggregates(ctx, startTime, endTime, granularity)
|
||||
if aggregatedErr == nil && len(aggregated) > 0 {
|
||||
return aggregated, nil
|
||||
}
|
||||
}
|
||||
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
@@ -1663,7 +1667,8 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
@@ -1719,6 +1724,80 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func shouldUsePreaggregatedTrend(granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) bool {
|
||||
if granularity != "day" && granularity != "hour" {
|
||||
return false
|
||||
}
|
||||
return userID == 0 &&
|
||||
apiKeyID == 0 &&
|
||||
accountID == 0 &&
|
||||
groupID == 0 &&
|
||||
model == "" &&
|
||||
requestType == nil &&
|
||||
stream == nil &&
|
||||
billingType == nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
query := ""
|
||||
args := []any{startTime, endTime}
|
||||
|
||||
switch granularity {
|
||||
case "hour":
|
||||
query = fmt.Sprintf(`
|
||||
SELECT
|
||||
TO_CHAR(bucket_start, '%s') as date,
|
||||
total_requests as requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
|
||||
total_cost as cost,
|
||||
actual_cost
|
||||
FROM usage_dashboard_hourly
|
||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||
ORDER BY bucket_start ASC
|
||||
`, dateFormat)
|
||||
case "day":
|
||||
query = fmt.Sprintf(`
|
||||
SELECT
|
||||
TO_CHAR(bucket_date::timestamp, '%s') as date,
|
||||
total_requests as requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens,
|
||||
total_cost as cost,
|
||||
actual_cost
|
||||
FROM usage_dashboard_daily
|
||||
WHERE bucket_date >= $1::date AND bucket_date < $2::date
|
||||
ORDER BY bucket_date ASC
|
||||
`, dateFormat)
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
results, err = scanTrendRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetModelStatsWithFilters returns model statistics with optional filters
|
||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
@@ -1733,6 +1812,8 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
%s
|
||||
@@ -2166,6 +2247,35 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
|
||||
return logs, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
limit := params.Limit()
|
||||
offset := params.Offset()
|
||||
|
||||
limitPos := len(args) + 1
|
||||
offsetPos := len(args) + 2
|
||||
listArgs := append(append([]any{}, args...), limit+1, offset)
|
||||
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
|
||||
|
||||
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
hasMore := false
|
||||
if len(logs) > limit {
|
||||
hasMore = true
|
||||
logs = logs[:limit]
|
||||
}
|
||||
|
||||
total := int64(offset) + int64(len(logs))
|
||||
if hasMore {
|
||||
// 只保证“还有下一页”,避免对超大表做全量 COUNT(*)。
|
||||
total = int64(offset) + int64(limit) + 1
|
||||
}
|
||||
|
||||
return logs, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
@@ -2520,7 +2630,8 @@ func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
|
||||
&row.Requests,
|
||||
&row.InputTokens,
|
||||
&row.OutputTokens,
|
||||
&row.CacheTokens,
|
||||
&row.CacheCreationTokens,
|
||||
&row.CacheReadTokens,
|
||||
&row.TotalTokens,
|
||||
&row.Cost,
|
||||
&row.ActualCost,
|
||||
@@ -2544,6 +2655,8 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
|
||||
&row.Requests,
|
||||
&row.InputTokens,
|
||||
&row.OutputTokens,
|
||||
&row.CacheCreationTokens,
|
||||
&row.CacheReadTokens,
|
||||
&row.TotalTokens,
|
||||
&row.Cost,
|
||||
&row.ActualCost,
|
||||
|
||||
@@ -96,6 +96,7 @@ func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
|
||||
filters := usagestats.UsageLogFilters{
|
||||
RequestType: &requestType,
|
||||
Stream: &stream,
|
||||
ExactTotal: true,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
|
||||
@@ -124,7 +125,7 @@ func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testin
|
||||
|
||||
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
|
||||
WithArgs(start, end, requestType).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
|
||||
trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
|
||||
require.NoError(t, err)
|
||||
@@ -143,7 +144,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin
|
||||
|
||||
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
|
||||
WithArgs(start, end, requestType).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
|
||||
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -243,21 +243,24 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
userMap[u.ID] = &outUsers[len(outUsers)-1]
|
||||
}
|
||||
|
||||
// Batch load active subscriptions with groups to avoid N+1.
|
||||
subs, err := r.client.UserSubscription.Query().
|
||||
Where(
|
||||
usersubscription.UserIDIn(userIDs...),
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
).
|
||||
WithGroup().
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
shouldLoadSubscriptions := filters.IncludeSubscriptions == nil || *filters.IncludeSubscriptions
|
||||
if shouldLoadSubscriptions {
|
||||
// Batch load active subscriptions with groups to avoid N+1.
|
||||
subs, err := r.client.UserSubscription.Query().
|
||||
Where(
|
||||
usersubscription.UserIDIn(userIDs...),
|
||||
usersubscription.StatusEQ(service.SubscriptionStatusActive),
|
||||
).
|
||||
WithGroup().
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for i := range subs {
|
||||
if u, ok := userMap[subs[i].UserID]; ok {
|
||||
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
|
||||
for i := range subs {
|
||||
if u, ok := userMap[subs[i].UserID]; ok {
|
||||
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -53,7 +53,9 @@ var ProviderSet = wire.NewSet(
|
||||
NewAPIKeyRepository,
|
||||
NewGroupRepository,
|
||||
NewAccountRepository,
|
||||
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
||||
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
||||
NewScheduledTestPlanRepository, // 定时测试计划仓储
|
||||
NewScheduledTestResultRepository, // 定时测试结果仓储
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
|
||||
@@ -86,6 +86,15 @@ func TestAPIContracts(t *testing.T) {
|
||||
"last_used_at": null,
|
||||
"quota": 0,
|
||||
"quota_used": 0,
|
||||
"rate_limit_5h": 0,
|
||||
"rate_limit_1d": 0,
|
||||
"rate_limit_7d": 0,
|
||||
"usage_5h": 0,
|
||||
"usage_1d": 0,
|
||||
"usage_7d": 0,
|
||||
"window_5h_start": null,
|
||||
"window_1d_start": null,
|
||||
"window_7d_start": null,
|
||||
"expires_at": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
@@ -126,6 +135,15 @@ func TestAPIContracts(t *testing.T) {
|
||||
"last_used_at": null,
|
||||
"quota": 0,
|
||||
"quota_used": 0,
|
||||
"rate_limit_5h": 0,
|
||||
"rate_limit_1d": 0,
|
||||
"rate_limit_7d": 0,
|
||||
"usage_5h": 0,
|
||||
"usage_1d": 0,
|
||||
"usage_7d": 0,
|
||||
"window_5h_start": null,
|
||||
"window_1d_start": null,
|
||||
"window_7d_start": null,
|
||||
"expires_at": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
@@ -428,9 +446,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
t.Helper()
|
||||
deps.settingRepo.SetAll(map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyEmailVerifyEnabled: "false",
|
||||
service.SettingKeyPromoCodeEnabled: "true",
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyEmailVerifyEnabled: "false",
|
||||
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
|
||||
service.SettingKeyPromoCodeEnabled: "true",
|
||||
|
||||
service.SettingKeySMTPHost: "smtp.example.com",
|
||||
service.SettingKeySMTPPort: "587",
|
||||
@@ -469,6 +488,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"data": {
|
||||
"registration_enabled": true,
|
||||
"email_verify_enabled": false,
|
||||
"registration_email_suffix_whitelist": [],
|
||||
"promo_code_enabled": true,
|
||||
"password_reset_enabled": false,
|
||||
"totp_enabled": false,
|
||||
@@ -514,6 +534,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"purchase_subscription_enabled": false,
|
||||
"purchase_subscription_url": "",
|
||||
"min_claude_code_version": "",
|
||||
"allow_ungrouped_key_scheduling": false,
|
||||
"custom_menu_items": []
|
||||
}
|
||||
}`,
|
||||
@@ -1027,6 +1048,14 @@ func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -1384,7 +1413,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
ids := make([]int64, 0, len(r.byID))
|
||||
for id := range r.byID {
|
||||
if r.byID[id].UserID == userID {
|
||||
@@ -1498,6 +1527,16 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type stubUsageLogRepo struct {
|
||||
userLogs map[int64][]service.UsageLog
|
||||
}
|
||||
|
||||
@@ -19,8 +19,16 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS
|
||||
}
|
||||
|
||||
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
//
|
||||
// 中间件职责分为两层:
|
||||
// - 鉴权(Authentication):验证 Key 有效性、用户状态、IP 限制 —— 始终执行
|
||||
// - 计费执行(Billing Enforcement):过期/配额/订阅/余额检查 —— skipBilling 时整块跳过
|
||||
//
|
||||
// /v1/usage 端点只需鉴权,不需要计费执行(允许过期/配额耗尽的 Key 查询自身用量)。
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// ── 1. 提取 API Key ──────────────────────────────────────────
|
||||
|
||||
queryKey := strings.TrimSpace(c.Query("key"))
|
||||
queryApiKey := strings.TrimSpace(c.Query("api_key"))
|
||||
if queryKey != "" || queryApiKey != "" {
|
||||
@@ -56,7 +64,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库验证API key
|
||||
// ── 2. 验证 Key 存在 ─────────────────────────────────────────
|
||||
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrAPIKeyNotFound) {
|
||||
@@ -67,29 +76,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// 检查API key是否激活
|
||||
if !apiKey.IsActive() {
|
||||
// Provide more specific error message based on status
|
||||
switch apiKey.Status {
|
||||
case service.StatusAPIKeyQuotaExhausted:
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
case service.StatusAPIKeyExpired:
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
default:
|
||||
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
||||
}
|
||||
return
|
||||
}
|
||||
// ── 3. 基础鉴权(始终执行) ─────────────────────────────────
|
||||
|
||||
// 检查API Key是否过期(即使状态是active,也要检查时间)
|
||||
if apiKey.IsExpired() {
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查API Key配额是否耗尽
|
||||
if apiKey.IsQuotaExhausted() {
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
// disabled / 未知状态 → 无条件拦截(expired 和 quota_exhausted 留给计费阶段)
|
||||
if !apiKey.IsActive() &&
|
||||
apiKey.Status != service.StatusAPIKeyExpired &&
|
||||
apiKey.Status != service.StatusAPIKeyQuotaExhausted {
|
||||
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -116,8 +109,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// ── 4. SimpleMode → early return ─────────────────────────────
|
||||
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
@@ -130,54 +124,89 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
// ── 5. 加载订阅(订阅模式时始终加载) ───────────────────────
|
||||
|
||||
// skipBilling: /v1/usage 只需鉴权,跳过所有计费执行
|
||||
skipBilling := c.Request.URL.Path == "/v1/usage"
|
||||
|
||||
var subscription *service.UserSubscription
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
|
||||
if isSubscriptionType && subscriptionService != nil {
|
||||
// 订阅模式:获取订阅(L1 缓存 + singleflight)
|
||||
subscription, err := subscriptionService.GetActiveSubscription(
|
||||
sub, subErr := subscriptionService.GetActiveSubscription(
|
||||
c.Request.Context(),
|
||||
apiKey.User.ID,
|
||||
apiKey.Group.ID,
|
||||
)
|
||||
if err != nil {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
|
||||
// 合并验证 + 限额检查(纯内存操作)
|
||||
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
|
||||
if err != nil {
|
||||
code := "SUBSCRIPTION_INVALID"
|
||||
status := 403
|
||||
if errors.Is(err, service.ErrDailyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
|
||||
code = "USAGE_LIMIT_EXCEEDED"
|
||||
status = 429
|
||||
if subErr != nil {
|
||||
if !skipBilling {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
AbortWithError(c, status, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 将订阅信息存入上下文
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
|
||||
// 窗口维护异步化(不阻塞请求)
|
||||
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
|
||||
if needsMaintenance {
|
||||
maintenanceCopy := *subscription
|
||||
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
}
|
||||
} else {
|
||||
// 余额模式:检查用户余额
|
||||
if apiKey.User.Balance <= 0 {
|
||||
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
|
||||
return
|
||||
// skipBilling: 订阅不存在也放行,handler 会返回可用的数据
|
||||
} else {
|
||||
subscription = sub
|
||||
}
|
||||
}
|
||||
|
||||
// 将API key和用户信息存入上下文
|
||||
// ── 6. 计费执行(skipBilling 时整块跳过) ────────────────────
|
||||
|
||||
if !skipBilling {
|
||||
// Key 状态检查
|
||||
switch apiKey.Status {
|
||||
case service.StatusAPIKeyQuotaExhausted:
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
return
|
||||
case service.StatusAPIKeyExpired:
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
return
|
||||
}
|
||||
|
||||
// 运行时过期/配额检查(即使状态是 active,也要检查时间和用量)
|
||||
if apiKey.IsExpired() {
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
return
|
||||
}
|
||||
if apiKey.IsQuotaExhausted() {
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
return
|
||||
}
|
||||
|
||||
// 订阅模式:验证订阅限额
|
||||
if subscription != nil {
|
||||
needsMaintenance, validateErr := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
|
||||
if validateErr != nil {
|
||||
code := "SUBSCRIPTION_INVALID"
|
||||
status := 403
|
||||
if errors.Is(validateErr, service.ErrDailyLimitExceeded) ||
|
||||
errors.Is(validateErr, service.ErrWeeklyLimitExceeded) ||
|
||||
errors.Is(validateErr, service.ErrMonthlyLimitExceeded) {
|
||||
code = "USAGE_LIMIT_EXCEEDED"
|
||||
status = 429
|
||||
}
|
||||
AbortWithError(c, status, code, validateErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 窗口维护异步化(不阻塞请求)
|
||||
if needsMaintenance {
|
||||
maintenanceCopy := *subscription
|
||||
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
}
|
||||
} else {
|
||||
// 非订阅模式 或 订阅模式但 subscriptionService 未注入:回退到余额检查
|
||||
if apiKey.User.Balance <= 0 {
|
||||
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── 7. 设置上下文 → Next ─────────────────────────────────────
|
||||
|
||||
if subscription != nil {
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
}
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
|
||||
@@ -56,7 +56,7 @@ func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
@@ -95,6 +95,15 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (f fakeAPIKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (f fakeAPIKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
|
||||
return &service.APIKeyRateLimitData{}, nil
|
||||
}
|
||||
|
||||
func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
|
||||
@@ -537,7 +537,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -588,6 +588,16 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type stubUserSubscriptionRepo struct {
|
||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||
|
||||
@@ -2,8 +2,11 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -71,3 +74,48 @@ func AbortWithError(c *gin.Context, statusCode int, code, message string) {
|
||||
c.JSON(statusCode, NewErrorResponse(code, message))
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// RequireGroupAssignment — 未分组 Key 拦截中间件
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// GatewayErrorWriter 定义网关错误响应格式(不同协议使用不同格式)
|
||||
type GatewayErrorWriter func(c *gin.Context, status int, message string)
|
||||
|
||||
// AnthropicErrorWriter 按 Anthropic API 规范输出错误
|
||||
func AnthropicErrorWriter(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{"type": "permission_error", "message": message},
|
||||
})
|
||||
}
|
||||
|
||||
// GoogleErrorWriter 按 Google API 规范输出错误
|
||||
func GoogleErrorWriter(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": googleapi.HTTPStatusToGoogleStatus(status),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// RequireGroupAssignment 检查 API Key 是否已分配到分组,
|
||||
// 如果未分组且系统设置不允许未分组 Key 调度则返回 403。
|
||||
func RequireGroupAssignment(settingService *service.SettingService, writeError GatewayErrorWriter) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
apiKey, ok := GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey.GroupID != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
// 未分组 Key — 检查系统设置
|
||||
if settingService.IsUngroupedKeySchedulingAllowed(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
writeError(c, http.StatusForbidden, "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ func SetupRouter(
|
||||
}
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient)
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||
|
||||
return r
|
||||
}
|
||||
@@ -96,6 +96,7 @@ func registerRoutes(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
opsService *service.OpsService,
|
||||
settingService *service.SettingService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) {
|
||||
@@ -110,5 +111,5 @@ func registerRoutes(
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||
}
|
||||
|
||||
@@ -78,6 +78,9 @@ func RegisterAdminRoutes(
|
||||
|
||||
// API Key 管理
|
||||
registerAdminAPIKeyRoutes(admin, h)
|
||||
|
||||
// 定时测试计划
|
||||
registerScheduledTestRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,6 +171,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth)
|
||||
|
||||
// Dashboard (vNext - raw path for MVP)
|
||||
ops.GET("/dashboard/snapshot-v2", h.Admin.Ops.GetDashboardSnapshotV2)
|
||||
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
|
||||
ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
|
||||
ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram)
|
||||
@@ -180,6 +184,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
dashboard := admin.Group("/dashboard")
|
||||
{
|
||||
dashboard.GET("/snapshot-v2", h.Admin.Dashboard.GetSnapshotV2)
|
||||
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
|
||||
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
||||
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
||||
@@ -476,6 +481,18 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerScheduledTestRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
plans := admin.Group("/scheduled-test-plans")
|
||||
{
|
||||
plans.POST("", h.Admin.ScheduledTest.Create)
|
||||
plans.PUT("/:id", h.Admin.ScheduledTest.Update)
|
||||
plans.DELETE("/:id", h.Admin.ScheduledTest.Delete)
|
||||
plans.GET("/:id/results", h.Admin.ScheduledTest.ListResults)
|
||||
}
|
||||
// Nested under accounts
|
||||
admin.GET("/accounts/:id/scheduled-test-plans", h.Admin.ScheduledTest.ListByAccount)
|
||||
}
|
||||
|
||||
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
rules := admin.Group("/error-passthrough-rules")
|
||||
{
|
||||
|
||||
@@ -19,6 +19,7 @@ func RegisterGatewayRoutes(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
opsService *service.OpsService,
|
||||
settingService *service.SettingService,
|
||||
cfg *config.Config,
|
||||
) {
|
||||
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
|
||||
@@ -30,12 +31,17 @@ func RegisterGatewayRoutes(
|
||||
clientRequestID := middleware.ClientRequestID()
|
||||
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
|
||||
|
||||
// 未分组 Key 拦截中间件(按协议格式区分错误响应)
|
||||
requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter)
|
||||
requireGroupGoogle := middleware.RequireGroupAssignment(settingService, middleware.GoogleErrorWriter)
|
||||
|
||||
// API网关(Claude API兼容)
|
||||
gateway := r.Group("/v1")
|
||||
gateway.Use(bodyLimit)
|
||||
gateway.Use(clientRequestID)
|
||||
gateway.Use(opsErrorLogger)
|
||||
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
gateway.Use(requireGroupAnthropic)
|
||||
{
|
||||
gateway.POST("/messages", h.Gateway.Messages)
|
||||
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
@@ -61,6 +67,7 @@ func RegisterGatewayRoutes(
|
||||
gemini.Use(clientRequestID)
|
||||
gemini.Use(opsErrorLogger)
|
||||
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
gemini.Use(requireGroupGoogle)
|
||||
{
|
||||
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
|
||||
@@ -69,11 +76,11 @@ func RegisterGatewayRoutes(
|
||||
}
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ResponsesWebSocket)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||
|
||||
// Antigravity 模型列表
|
||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
|
||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
||||
|
||||
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
|
||||
antigravityV1 := r.Group("/antigravity/v1")
|
||||
@@ -82,6 +89,7 @@ func RegisterGatewayRoutes(
|
||||
antigravityV1.Use(opsErrorLogger)
|
||||
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
antigravityV1.Use(requireGroupAnthropic)
|
||||
{
|
||||
antigravityV1.POST("/messages", h.Gateway.Messages)
|
||||
antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
@@ -95,6 +103,7 @@ func RegisterGatewayRoutes(
|
||||
antigravityV1Beta.Use(opsErrorLogger)
|
||||
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
antigravityV1Beta.Use(requireGroupGoogle)
|
||||
{
|
||||
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
|
||||
@@ -108,6 +117,7 @@ func RegisterGatewayRoutes(
|
||||
soraV1.Use(opsErrorLogger)
|
||||
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
|
||||
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
soraV1.Use(requireGroupAnthropic)
|
||||
{
|
||||
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
|
||||
soraV1.GET("/models", h.Gateway.Models)
|
||||
|
||||
@@ -853,15 +853,21 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
|
||||
}
|
||||
|
||||
const (
|
||||
OpenAIWSIngressModeOff = "off"
|
||||
OpenAIWSIngressModeShared = "shared"
|
||||
OpenAIWSIngressModeDedicated = "dedicated"
|
||||
OpenAIWSIngressModeOff = "off"
|
||||
OpenAIWSIngressModeShared = "shared"
|
||||
OpenAIWSIngressModeDedicated = "dedicated"
|
||||
OpenAIWSIngressModeCtxPool = "ctx_pool"
|
||||
OpenAIWSIngressModePassthrough = "passthrough"
|
||||
)
|
||||
|
||||
func normalizeOpenAIWSIngressMode(mode string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case OpenAIWSIngressModeOff:
|
||||
return OpenAIWSIngressModeOff
|
||||
case OpenAIWSIngressModeCtxPool:
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
case OpenAIWSIngressModePassthrough:
|
||||
return OpenAIWSIngressModePassthrough
|
||||
case OpenAIWSIngressModeShared:
|
||||
return OpenAIWSIngressModeShared
|
||||
case OpenAIWSIngressModeDedicated:
|
||||
@@ -873,18 +879,21 @@ func normalizeOpenAIWSIngressMode(mode string) string {
|
||||
|
||||
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
|
||||
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
|
||||
if normalized == OpenAIWSIngressModeShared || normalized == OpenAIWSIngressModeDedicated {
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
return OpenAIWSIngressModeShared
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
}
|
||||
|
||||
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。
|
||||
// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool/passthrough)。
|
||||
//
|
||||
// 优先级:
|
||||
// 1. 分类型 mode 新字段(string)
|
||||
// 2. 分类型 enabled 旧字段(bool)
|
||||
// 3. 兼容 enabled 旧字段(bool)
|
||||
// 4. defaultMode(非法时回退 shared)
|
||||
// 4. defaultMode(非法时回退 ctx_pool)
|
||||
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
|
||||
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
|
||||
if a == nil || !a.IsOpenAI() {
|
||||
@@ -919,7 +928,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
||||
return "", false
|
||||
}
|
||||
if enabled {
|
||||
return OpenAIWSIngressModeShared, true
|
||||
return OpenAIWSIngressModeCtxPool, true
|
||||
}
|
||||
return OpenAIWSIngressModeOff, true
|
||||
}
|
||||
@@ -946,6 +955,10 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
||||
if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
|
||||
return mode
|
||||
}
|
||||
// 兼容旧值:shared/dedicated 语义都归并到 ctx_pool。
|
||||
if resolvedDefault == OpenAIWSIngressModeShared || resolvedDefault == OpenAIWSIngressModeDedicated {
|
||||
return OpenAIWSIngressModeCtxPool
|
||||
}
|
||||
return resolvedDefault
|
||||
}
|
||||
|
||||
|
||||
@@ -206,14 +206,14 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
t.Run("default fallback to shared", func(t *testing.T) {
|
||||
t.Run("default fallback to ctx_pool", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
|
||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
|
||||
})
|
||||
|
||||
t.Run("oauth mode field has highest priority", func(t *testing.T) {
|
||||
@@ -221,15 +221,15 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||
"openai_oauth_responses_websockets_v2_enabled": false,
|
||||
"responses_websockets_v2_enabled": false,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
|
||||
require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||
})
|
||||
|
||||
t.Run("legacy enabled maps to shared", func(t *testing.T) {
|
||||
t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
@@ -237,7 +237,28 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
})
|
||||
|
||||
t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) {
|
||||
shared := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
|
||||
},
|
||||
}
|
||||
dedicated := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated))
|
||||
})
|
||||
|
||||
t.Run("legacy disabled maps to off", func(t *testing.T) {
|
||||
@@ -249,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
|
||||
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||
})
|
||||
|
||||
t.Run("non openai always off", func(t *testing.T) {
|
||||
|
||||
@@ -54,6 +54,8 @@ type AccountRepository interface {
|
||||
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
|
||||
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
|
||||
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||
ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
|
||||
|
||||
@@ -147,6 +147,14 @@ func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte
|
||||
panic("unexpected ListSchedulableByGroupIDAndPlatforms call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableUngroupedByPlatform call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableUngroupedByPlatforms call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
panic("unexpected SetRateLimited call")
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -33,7 +34,7 @@ import (
|
||||
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
||||
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
|
||||
soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
|
||||
@@ -238,7 +239,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages?beta=true"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
@@ -1560,3 +1561,62 @@ func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) er
|
||||
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
|
||||
return fmt.Errorf("%s", errorMsg)
|
||||
}
|
||||
|
||||
// RunTestBackground executes an account test in-memory (no real HTTP client),
|
||||
// capturing SSE output via httptest.NewRecorder, then parses the result.
|
||||
func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID int64, modelID string) (*ScheduledTestResult, error) {
|
||||
startedAt := time.Now()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(w)
|
||||
ginCtx.Request = (&http.Request{}).WithContext(ctx)
|
||||
|
||||
testErr := s.TestAccountConnection(ginCtx, accountID, modelID)
|
||||
|
||||
finishedAt := time.Now()
|
||||
body := w.Body.String()
|
||||
responseText, errMsg := parseTestSSEOutput(body)
|
||||
|
||||
status := "success"
|
||||
if testErr != nil || errMsg != "" {
|
||||
status = "failed"
|
||||
if errMsg == "" && testErr != nil {
|
||||
errMsg = testErr.Error()
|
||||
}
|
||||
}
|
||||
|
||||
return &ScheduledTestResult{
|
||||
Status: status,
|
||||
ResponseText: responseText,
|
||||
ErrorMessage: errMsg,
|
||||
LatencyMs: finishedAt.Sub(startedAt).Milliseconds(),
|
||||
StartedAt: startedAt,
|
||||
FinishedAt: finishedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseTestSSEOutput extracts response text and error message from captured SSE output.
|
||||
func parseTestSSEOutput(body string) (responseText, errMsg string) {
|
||||
var texts []string
|
||||
for _, line := range strings.Split(body, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||
var event TestEvent
|
||||
if err := json.Unmarshal([]byte(jsonStr), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
switch event.Type {
|
||||
case "content":
|
||||
if event.Text != "" {
|
||||
texts = append(texts, event.Text)
|
||||
}
|
||||
case "error":
|
||||
errMsg = event.Error
|
||||
}
|
||||
}
|
||||
responseText = strings.Join(texts, "")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -745,7 +745,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string)
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
||||
@@ -127,6 +127,15 @@ func (s *apiKeyRepoStubForGroupUpdate) IncrementQuotaUsed(context.Context, int64
|
||||
func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) IncrementRateLimitUsage(context.Context, int64, float64) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, int64) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
|
||||
type groupRepoStubForGroupUpdate struct {
|
||||
|
||||
@@ -348,6 +348,19 @@ func (s *billingCacheStub) InvalidateSubscriptionCache(ctx context.Context, user
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
|
||||
panic("unexpected GetAPIKeyRateLimit call")
|
||||
}
|
||||
func (s *billingCacheStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
|
||||
panic("unexpected SetAPIKeyRateLimit call")
|
||||
}
|
||||
func (s *billingCacheStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||
panic("unexpected UpdateAPIKeyRateLimitUsage call")
|
||||
}
|
||||
func (s *billingCacheStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||||
panic("unexpected InvalidateAPIKeyRateLimit call")
|
||||
}
|
||||
|
||||
func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall {
|
||||
t.Helper()
|
||||
calls := make([]subscriptionInvalidateCall, 0, expected)
|
||||
|
||||
@@ -36,12 +36,28 @@ type APIKey struct {
|
||||
Quota float64 // Quota limit in USD (0 = unlimited)
|
||||
QuotaUsed float64 // Used quota amount
|
||||
ExpiresAt *time.Time // Expiration time (nil = never expires)
|
||||
|
||||
// Rate limit fields
|
||||
RateLimit5h float64 // Rate limit in USD per 5h (0 = unlimited)
|
||||
RateLimit1d float64 // Rate limit in USD per 1d (0 = unlimited)
|
||||
RateLimit7d float64 // Rate limit in USD per 7d (0 = unlimited)
|
||||
Usage5h float64 // Used amount in current 5h window
|
||||
Usage1d float64 // Used amount in current 1d window
|
||||
Usage7d float64 // Used amount in current 7d window
|
||||
Window5hStart *time.Time // Start of current 5h window
|
||||
Window1dStart *time.Time // Start of current 1d window
|
||||
Window7dStart *time.Time // Start of current 7d window
|
||||
}
|
||||
|
||||
func (k *APIKey) IsActive() bool {
|
||||
return k.Status == StatusActive
|
||||
}
|
||||
|
||||
// HasRateLimits returns true if any rate limit window is configured
|
||||
func (k *APIKey) HasRateLimits() bool {
|
||||
return k.RateLimit5h > 0 || k.RateLimit1d > 0 || k.RateLimit7d > 0
|
||||
}
|
||||
|
||||
// IsExpired checks if the API key has expired
|
||||
func (k *APIKey) IsExpired() bool {
|
||||
if k.ExpiresAt == nil {
|
||||
@@ -81,3 +97,10 @@ func (k *APIKey) GetDaysUntilExpiry() int {
|
||||
}
|
||||
return int(duration.Hours() / 24)
|
||||
}
|
||||
|
||||
// APIKeyListFilters holds optional filtering parameters for listing API keys.
|
||||
type APIKeyListFilters struct {
|
||||
Search string
|
||||
Status string
|
||||
GroupID *int64 // nil=不筛选, 0=无分组, >0=指定分组
|
||||
}
|
||||
|
||||
@@ -19,6 +19,11 @@ type APIKeyAuthSnapshot struct {
|
||||
|
||||
// Expiration field for API Key expiration feature
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires)
|
||||
|
||||
// Rate limit configuration (only limits, not usage - usage read from Redis at check time)
|
||||
RateLimit5h float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d float64 `json:"rate_limit_7d"`
|
||||
}
|
||||
|
||||
// APIKeyAuthUserSnapshot 用户快照
|
||||
|
||||
@@ -209,6 +209,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
Quota: apiKey.Quota,
|
||||
QuotaUsed: apiKey.QuotaUsed,
|
||||
ExpiresAt: apiKey.ExpiresAt,
|
||||
RateLimit5h: apiKey.RateLimit5h,
|
||||
RateLimit1d: apiKey.RateLimit1d,
|
||||
RateLimit7d: apiKey.RateLimit7d,
|
||||
User: APIKeyAuthUserSnapshot{
|
||||
ID: apiKey.User.ID,
|
||||
Status: apiKey.User.Status,
|
||||
@@ -262,6 +265,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
Quota: snapshot.Quota,
|
||||
QuotaUsed: snapshot.QuotaUsed,
|
||||
ExpiresAt: snapshot.ExpiresAt,
|
||||
RateLimit5h: snapshot.RateLimit5h,
|
||||
RateLimit1d: snapshot.RateLimit1d,
|
||||
RateLimit7d: snapshot.RateLimit7d,
|
||||
User: &User{
|
||||
ID: snapshot.User.ID,
|
||||
Status: snapshot.User.Status,
|
||||
|
||||
@@ -30,6 +30,11 @@ var (
|
||||
ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期")
|
||||
// ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted")
|
||||
ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完")
|
||||
|
||||
// Rate limit errors
|
||||
ErrAPIKeyRateLimit5hExceeded = infraerrors.TooManyRequests("API_KEY_RATE_5H_EXCEEDED", "api key 5小时限额已用完")
|
||||
ErrAPIKeyRateLimit1dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_1D_EXCEEDED", "api key 日限额已用完")
|
||||
ErrAPIKeyRateLimit7dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_7D_EXCEEDED", "api key 7天限额已用完")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -50,7 +55,7 @@ type APIKeyRepository interface {
|
||||
Update(ctx context.Context, key *APIKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error)
|
||||
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
@@ -64,6 +69,21 @@ type APIKeyRepository interface {
|
||||
// Quota methods
|
||||
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error)
|
||||
UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error
|
||||
|
||||
// Rate limit methods
|
||||
IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error
|
||||
ResetRateLimitWindows(ctx context.Context, id int64) error
|
||||
GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error)
|
||||
}
|
||||
|
||||
// APIKeyRateLimitData holds rate limit usage and window state for an API key.
|
||||
type APIKeyRateLimitData struct {
|
||||
Usage5h float64
|
||||
Usage1d float64
|
||||
Usage7d float64
|
||||
Window5hStart *time.Time
|
||||
Window1dStart *time.Time
|
||||
Window7dStart *time.Time
|
||||
}
|
||||
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
@@ -102,6 +122,11 @@ type CreateAPIKeyRequest struct {
|
||||
// Quota fields
|
||||
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
|
||||
ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires)
|
||||
|
||||
// Rate limit fields (0 = unlimited)
|
||||
RateLimit5h float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d float64 `json:"rate_limit_7d"`
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest 更新API Key请求
|
||||
@@ -117,22 +142,34 @@ type UpdateAPIKeyRequest struct {
|
||||
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change)
|
||||
ClearExpiration bool `json:"-"` // Clear expiration (internal use)
|
||||
ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0
|
||||
|
||||
// Rate limit fields (nil = no change, 0 = unlimited)
|
||||
RateLimit5h *float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d *float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d *float64 `json:"rate_limit_7d"`
|
||||
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // Reset all usage counters to 0
|
||||
}
|
||||
|
||||
// APIKeyService API Key服务
|
||||
// RateLimitCacheInvalidator invalidates rate limit cache entries on manual reset.
|
||||
type RateLimitCacheInvalidator interface {
|
||||
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
|
||||
}
|
||||
|
||||
type APIKeyService struct {
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache APIKeyCache
|
||||
cfg *config.Config
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
|
||||
lastUsedTouchSF singleflight.Group
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache APIKeyCache
|
||||
rateLimitCacheInvalid RateLimitCacheInvalidator // optional: invalidate Redis rate limit cache
|
||||
cfg *config.Config
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
|
||||
lastUsedTouchSF singleflight.Group
|
||||
}
|
||||
|
||||
// NewAPIKeyService 创建API Key服务实例
|
||||
@@ -158,6 +195,12 @@ func NewAPIKeyService(
|
||||
return svc
|
||||
}
|
||||
|
||||
// SetRateLimitCacheInvalidator sets the optional rate limit cache invalidator.
|
||||
// Called after construction (e.g. in wire) to avoid circular dependencies.
|
||||
func (s *APIKeyService) SetRateLimitCacheInvalidator(inv RateLimitCacheInvalidator) {
|
||||
s.rateLimitCacheInvalid = inv
|
||||
}
|
||||
|
||||
func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) {
|
||||
if apiKey == nil {
|
||||
return
|
||||
@@ -327,6 +370,9 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
Quota: req.Quota,
|
||||
QuotaUsed: 0,
|
||||
RateLimit5h: req.RateLimit5h,
|
||||
RateLimit1d: req.RateLimit1d,
|
||||
RateLimit7d: req.RateLimit7d,
|
||||
}
|
||||
|
||||
// Set expiration time if specified
|
||||
@@ -346,8 +392,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, filters)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
}
|
||||
@@ -519,6 +565,26 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
apiKey.IPWhitelist = req.IPWhitelist
|
||||
apiKey.IPBlacklist = req.IPBlacklist
|
||||
|
||||
// Update rate limit configuration
|
||||
if req.RateLimit5h != nil {
|
||||
apiKey.RateLimit5h = *req.RateLimit5h
|
||||
}
|
||||
if req.RateLimit1d != nil {
|
||||
apiKey.RateLimit1d = *req.RateLimit1d
|
||||
}
|
||||
if req.RateLimit7d != nil {
|
||||
apiKey.RateLimit7d = *req.RateLimit7d
|
||||
}
|
||||
resetRateLimit := req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage
|
||||
if resetRateLimit {
|
||||
apiKey.Usage5h = 0
|
||||
apiKey.Usage1d = 0
|
||||
apiKey.Usage7d = 0
|
||||
apiKey.Window5hStart = nil
|
||||
apiKey.Window1dStart = nil
|
||||
apiKey.Window7dStart = nil
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("update api key: %w", err)
|
||||
}
|
||||
@@ -526,6 +592,11 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
|
||||
// Invalidate Redis rate limit cache so reset takes effect immediately
|
||||
if resetRateLimit && s.rateLimitCacheInvalid != nil {
|
||||
_ = s.rateLimitCacheInvalid.InvalidateAPIKeyRateLimit(ctx, apiKey.ID)
|
||||
}
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
@@ -746,3 +817,16 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRateLimitData returns rate limit usage and window state for an API key.
|
||||
func (s *APIKeyService) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
|
||||
return s.apiKeyRepo.GetRateLimitData(ctx, id)
|
||||
}
|
||||
|
||||
// UpdateRateLimitUsage atomically increments rate limit usage counters in the DB.
|
||||
func (s *APIKeyService) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||
if cost <= 0 {
|
||||
return nil
|
||||
}
|
||||
return s.apiKeyRepo.IncrementRateLimitUsage(ctx, apiKeyID, cost)
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
|
||||
@@ -106,6 +106,15 @@ func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount
|
||||
func (s *authRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
panic("unexpected UpdateLastUsed call")
|
||||
}
|
||||
func (s *authRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
panic("unexpected IncrementRateLimitUsage call")
|
||||
}
|
||||
func (s *authRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
panic("unexpected ResetRateLimitWindows call")
|
||||
}
|
||||
func (s *authRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
|
||||
panic("unexpected GetRateLimitData call")
|
||||
}
|
||||
|
||||
type authCacheStub struct {
|
||||
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||
|
||||
@@ -81,7 +81,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
|
||||
// 以下是接口要求实现但本测试不关心的方法
|
||||
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
|
||||
@@ -134,6 +134,18 @@ func (s *apiKeyRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
panic("unexpected IncrementRateLimitUsage call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
panic("unexpected ResetRateLimitWindows call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
|
||||
panic("unexpected GetRateLimitData call")
|
||||
}
|
||||
|
||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||
//
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -33,6 +34,7 @@ var (
|
||||
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
||||
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
||||
@@ -115,6 +117,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
if isReservedEmail(email) {
|
||||
return "", nil, ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// 检查是否需要邀请码
|
||||
var invitationRedeemCode *RedeemCode
|
||||
@@ -241,6 +246,9 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
if isReservedEmail(email) {
|
||||
return ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
@@ -279,6 +287,9 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
||||
if isReservedEmail(email) {
|
||||
return nil, ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
@@ -624,6 +635,32 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
|
||||
if s.settingService == nil {
|
||||
return nil
|
||||
}
|
||||
whitelist := s.settingService.GetRegistrationEmailSuffixWhitelist(ctx)
|
||||
if !IsRegistrationEmailSuffixAllowed(email, whitelist) {
|
||||
return buildEmailSuffixNotAllowedError(whitelist)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildEmailSuffixNotAllowedError(whitelist []string) error {
|
||||
if len(whitelist) == 0 {
|
||||
return ErrEmailSuffixNotAllowed
|
||||
}
|
||||
|
||||
allowed := strings.Join(whitelist, ", ")
|
||||
return infraerrors.BadRequest(
|
||||
"EMAIL_SUFFIX_NOT_ALLOWED",
|
||||
fmt.Sprintf("email suffix is not allowed, allowed suffixes: %s", allowed),
|
||||
).WithMetadata(map[string]string{
|
||||
"allowed_suffixes": strings.Join(whitelist, ","),
|
||||
"allowed_suffix_count": strconv.Itoa(len(whitelist)),
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -231,6 +232,51 @@ func TestAuthService_Register_ReservedEmail(t *testing.T) {
|
||||
require.ErrorIs(t, err, ErrEmailReserved)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@other.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
|
||||
appErr := infraerrors.FromError(err)
|
||||
require.Contains(t, appErr.Message, "@example.com")
|
||||
require.Contains(t, appErr.Message, "@company.com")
|
||||
require.Equal(t, "EMAIL_SUFFIX_NOT_ALLOWED", appErr.Reason)
|
||||
require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
|
||||
require.Equal(t, "@example.com,@company.com", appErr.Metadata["allowed_suffixes"])
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 8}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`,
|
||||
}, nil)
|
||||
|
||||
_, user, err := service.Register(context.Background(), "user@example.com", "password")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.Equal(t, int64(8), user.ID)
|
||||
}
|
||||
|
||||
func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
|
||||
}, nil)
|
||||
|
||||
err := service.SendVerifyCode(context.Background(), "user@other.com")
|
||||
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
|
||||
appErr := infraerrors.FromError(err)
|
||||
require.Contains(t, appErr.Message, "@example.com")
|
||||
require.Contains(t, appErr.Message, "@company.com")
|
||||
require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CreateError(t *testing.T) {
|
||||
repo := &userRepoStub{createErr: errors.New("create failed")}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
@@ -402,7 +448,7 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 42}
|
||||
assigner := &defaultSubscriptionAssignerStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
|
||||
}, nil)
|
||||
service.defaultSubAssigner = assigner
|
||||
|
||||
@@ -40,6 +40,7 @@ const (
|
||||
cacheWriteSetSubscription
|
||||
cacheWriteUpdateSubscriptionUsage
|
||||
cacheWriteDeductBalance
|
||||
cacheWriteUpdateRateLimitUsage
|
||||
)
|
||||
|
||||
// 异步缓存写入工作池配置
|
||||
@@ -68,19 +69,26 @@ type cacheWriteTask struct {
|
||||
kind cacheWriteKind
|
||||
userID int64
|
||||
groupID int64
|
||||
apiKeyID int64
|
||||
balance float64
|
||||
amount float64
|
||||
subscriptionData *subscriptionCacheData
|
||||
}
|
||||
|
||||
// apiKeyRateLimitLoader defines the interface for loading rate limit data from DB.
|
||||
type apiKeyRateLimitLoader interface {
|
||||
GetRateLimitData(ctx context.Context, keyID int64) (*APIKeyRateLimitData, error)
|
||||
}
|
||||
|
||||
// BillingCacheService 计费缓存服务
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
cache BillingCache
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
cfg *config.Config
|
||||
circuitBreaker *billingCircuitBreaker
|
||||
cache BillingCache
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
apiKeyRateLimitLoader apiKeyRateLimitLoader
|
||||
cfg *config.Config
|
||||
circuitBreaker *billingCircuitBreaker
|
||||
|
||||
cacheWriteChan chan cacheWriteTask
|
||||
cacheWriteWg sync.WaitGroup
|
||||
@@ -96,12 +104,13 @@ type BillingCacheService struct {
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService {
|
||||
svc := &BillingCacheService{
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
cfg: cfg,
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
apiKeyRateLimitLoader: apiKeyRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
|
||||
svc.startCacheWriteWorkers()
|
||||
@@ -188,6 +197,12 @@ func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err)
|
||||
}
|
||||
}
|
||||
case cacheWriteUpdateRateLimitUsage:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.UpdateAPIKeyRateLimitUsage(ctx, task.apiKeyID, task.amount); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: update rate limit usage cache failed for api key %d: %v", task.apiKeyID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
@@ -204,6 +219,8 @@ func cacheWriteKindName(kind cacheWriteKind) string {
|
||||
return "update_subscription_usage"
|
||||
case cacheWriteDeductBalance:
|
||||
return "deduct_balance"
|
||||
case cacheWriteUpdateRateLimitUsage:
|
||||
return "update_rate_limit_usage"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
@@ -476,6 +493,137 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// API Key 限速缓存方法
|
||||
// ============================================
|
||||
|
||||
// checkAPIKeyRateLimits checks rate limit windows for an API key.
|
||||
// It loads usage from Redis cache (falling back to DB on cache miss),
|
||||
// resets expired windows in-memory and triggers async DB reset,
|
||||
// and returns an error if any window limit is exceeded.
|
||||
func (s *BillingCacheService) checkAPIKeyRateLimits(ctx context.Context, apiKey *APIKey) error {
|
||||
if s.cache == nil {
|
||||
// No cache: fall back to reading from DB directly
|
||||
if s.apiKeyRateLimitLoader == nil {
|
||||
return nil
|
||||
}
|
||||
data, err := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
|
||||
if err != nil {
|
||||
return nil // Don't block requests on DB errors
|
||||
}
|
||||
return s.evaluateRateLimits(ctx, apiKey, data.Usage5h, data.Usage1d, data.Usage7d,
|
||||
data.Window5hStart, data.Window1dStart, data.Window7dStart)
|
||||
}
|
||||
|
||||
cacheData, err := s.cache.GetAPIKeyRateLimit(ctx, apiKey.ID)
|
||||
if err != nil {
|
||||
// Cache miss: load from DB and populate cache
|
||||
if s.apiKeyRateLimitLoader == nil {
|
||||
return nil
|
||||
}
|
||||
dbData, dbErr := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
|
||||
if dbErr != nil {
|
||||
return nil // Don't block requests on DB errors
|
||||
}
|
||||
// Build cache entry from DB data
|
||||
cacheEntry := &APIKeyRateLimitCacheData{
|
||||
Usage5h: dbData.Usage5h,
|
||||
Usage1d: dbData.Usage1d,
|
||||
Usage7d: dbData.Usage7d,
|
||||
}
|
||||
if dbData.Window5hStart != nil {
|
||||
cacheEntry.Window5h = dbData.Window5hStart.Unix()
|
||||
}
|
||||
if dbData.Window1dStart != nil {
|
||||
cacheEntry.Window1d = dbData.Window1dStart.Unix()
|
||||
}
|
||||
if dbData.Window7dStart != nil {
|
||||
cacheEntry.Window7d = dbData.Window7dStart.Unix()
|
||||
}
|
||||
_ = s.cache.SetAPIKeyRateLimit(ctx, apiKey.ID, cacheEntry)
|
||||
cacheData = cacheEntry
|
||||
}
|
||||
|
||||
var w5h, w1d, w7d *time.Time
|
||||
if cacheData.Window5h > 0 {
|
||||
t := time.Unix(cacheData.Window5h, 0)
|
||||
w5h = &t
|
||||
}
|
||||
if cacheData.Window1d > 0 {
|
||||
t := time.Unix(cacheData.Window1d, 0)
|
||||
w1d = &t
|
||||
}
|
||||
if cacheData.Window7d > 0 {
|
||||
t := time.Unix(cacheData.Window7d, 0)
|
||||
w7d = &t
|
||||
}
|
||||
return s.evaluateRateLimits(ctx, apiKey, cacheData.Usage5h, cacheData.Usage1d, cacheData.Usage7d, w5h, w1d, w7d)
|
||||
}
|
||||
|
||||
// evaluateRateLimits checks usage against limits, triggering async resets for expired windows.
|
||||
func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *APIKey, usage5h, usage1d, usage7d float64, w5h, w1d, w7d *time.Time) error {
|
||||
needsReset := false
|
||||
|
||||
// Reset expired windows in-memory for check purposes
|
||||
if w5h != nil && time.Since(*w5h) >= 5*time.Hour {
|
||||
usage5h = 0
|
||||
needsReset = true
|
||||
}
|
||||
if w1d != nil && time.Since(*w1d) >= 24*time.Hour {
|
||||
usage1d = 0
|
||||
needsReset = true
|
||||
}
|
||||
if w7d != nil && time.Since(*w7d) >= 7*24*time.Hour {
|
||||
usage7d = 0
|
||||
needsReset = true
|
||||
}
|
||||
|
||||
// Trigger async DB reset if any window expired
|
||||
if needsReset {
|
||||
keyID := apiKey.ID
|
||||
go func() {
|
||||
resetCtx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
if s.apiKeyRateLimitLoader != nil {
|
||||
// Use the repo directly - reset then reload cache
|
||||
if loader, ok := s.apiKeyRateLimitLoader.(interface {
|
||||
ResetRateLimitWindows(ctx context.Context, id int64) error
|
||||
}); ok {
|
||||
_ = loader.ResetRateLimitWindows(resetCtx, keyID)
|
||||
}
|
||||
}
|
||||
// Invalidate cache so next request loads fresh data
|
||||
if s.cache != nil {
|
||||
_ = s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Check limits
|
||||
if apiKey.RateLimit5h > 0 && usage5h >= apiKey.RateLimit5h {
|
||||
return ErrAPIKeyRateLimit5hExceeded
|
||||
}
|
||||
if apiKey.RateLimit1d > 0 && usage1d >= apiKey.RateLimit1d {
|
||||
return ErrAPIKeyRateLimit1dExceeded
|
||||
}
|
||||
if apiKey.RateLimit7d > 0 && usage7d >= apiKey.RateLimit7d {
|
||||
return ErrAPIKeyRateLimit7dExceeded
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueueUpdateAPIKeyRateLimitUsage asynchronously updates rate limit usage in the cache.
|
||||
func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, cost float64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteUpdateRateLimitUsage,
|
||||
apiKeyID: apiKeyID,
|
||||
amount: cost,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 统一检查方法
|
||||
// ============================================
|
||||
@@ -496,10 +644,23 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
|
||||
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
|
||||
|
||||
if isSubscriptionMode {
|
||||
return s.checkSubscriptionEligibility(ctx, user.ID, group, subscription)
|
||||
if err := s.checkSubscriptionEligibility(ctx, user.ID, group, subscription); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := s.checkBalanceEligibility(ctx, user.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return s.checkBalanceEligibility(ctx, user.ID)
|
||||
// Check API Key rate limits (applies to both billing modes)
|
||||
if apiKey != nil && apiKey.HasRateLimits() {
|
||||
if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkBalanceEligibility 检查余额模式资格
|
||||
|
||||
@@ -51,6 +51,22 @@ func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
|
||||
return nil, errors.New("cache miss")
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type balanceLoadUserRepoStub struct {
|
||||
mockUserRepo
|
||||
calls atomic.Int64
|
||||
@@ -76,7 +92,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
|
||||
delay: 80 * time.Millisecond,
|
||||
balance: 12.34,
|
||||
}
|
||||
svc := NewBillingCacheService(cache, userRepo, nil, &config.Config{})
|
||||
svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
const goroutines = 16
|
||||
|
||||
@@ -52,9 +52,25 @@ func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
start := time.Now()
|
||||
@@ -76,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
|
||||
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
|
||||
svc.Stop()
|
||||
|
||||
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
|
||||
|
||||
@@ -10,6 +10,16 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// APIKeyRateLimitCacheData holds rate limit usage data cached in Redis.
|
||||
type APIKeyRateLimitCacheData struct {
|
||||
Usage5h float64 `json:"usage_5h"`
|
||||
Usage1d float64 `json:"usage_1d"`
|
||||
Usage7d float64 `json:"usage_7d"`
|
||||
Window5h int64 `json:"window_5h"` // unix timestamp, 0 = not started
|
||||
Window1d int64 `json:"window_1d"`
|
||||
Window7d int64 `json:"window_7d"`
|
||||
}
|
||||
|
||||
// BillingCache defines cache operations for billing service
|
||||
type BillingCache interface {
|
||||
// Balance operations
|
||||
@@ -23,6 +33,12 @@ type BillingCache interface {
|
||||
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
|
||||
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
|
||||
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
|
||||
|
||||
// API Key rate limit operations
|
||||
GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error)
|
||||
SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error
|
||||
UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error
|
||||
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
|
||||
}
|
||||
|
||||
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
|
||||
|
||||
@@ -74,11 +74,12 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
// Setting keys
|
||||
const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // 注册邮箱后缀白名单(JSON 数组)
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
@@ -201,6 +202,9 @@ const (
|
||||
|
||||
// SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查)
|
||||
SettingKeyMinClaudeCodeVersion = "min_claude_code_version"
|
||||
|
||||
// SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false:未分组 Key 返回 403)
|
||||
SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling"
|
||||
)
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
|
||||
@@ -88,6 +88,49 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
body: []byte(`overloaded service`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_401_first_hit_returns_temp_unscheduled",
|
||||
account: &Account{
|
||||
ID: 14,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_401_second_hit_upgrades_to_none",
|
||||
account: &Account{
|
||||
ID: 15,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_body_miss_returns_none",
|
||||
account: &Account{
|
||||
|
||||
@@ -171,8 +171,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
|
||||
require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体")
|
||||
require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(upstream.lastBody, "model").String(), "透传模式应应用账号级模型映射")
|
||||
|
||||
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
||||
@@ -190,7 +189,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
||||
require.True(t, ok)
|
||||
bodyBytes, ok := rawBody.([]byte)
|
||||
require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝")
|
||||
require.Equal(t, body, bodyBytes)
|
||||
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(bodyBytes, "model").String(), "缓存的上游请求体应包含映射后的模型")
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) {
|
||||
@@ -253,8 +252,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体")
|
||||
require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, "claude-3-opus-20240229", gjson.GetBytes(upstream.lastBody, "model").String(), "count_tokens 透传模式应应用账号级模型映射")
|
||||
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("cookie"))
|
||||
@@ -263,6 +261,273 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
||||
require.Empty(t, rec.Header().Get("Set-Cookie"))
|
||||
}
|
||||
|
||||
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases 覆盖透传模式下模型映射的各种边界情况
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
modelMapping map[string]any // nil = 不配置映射
|
||||
expectedModel string
|
||||
endpoint string // "messages" or "count_tokens"
|
||||
}{
|
||||
{
|
||||
name: "Forward: 无映射配置时不改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: nil,
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "Forward: 空映射配置时不改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{},
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "Forward: 模型不在映射表中时不改写",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "Forward: 精确匹配映射应改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||
expectedModel: "claude-sonnet-4-5-20241022",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "Forward: 通配符映射应改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
|
||||
expectedModel: "claude-sonnet-4-5-20241022",
|
||||
endpoint: "messages",
|
||||
},
|
||||
{
|
||||
name: "CountTokens: 无映射配置时不改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: nil,
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "count_tokens",
|
||||
},
|
||||
{
|
||||
name: "CountTokens: 模型不在映射表中时不改写",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
|
||||
expectedModel: "claude-sonnet-4-20250514",
|
||||
endpoint: "count_tokens",
|
||||
},
|
||||
{
|
||||
name: "CountTokens: 精确匹配映射应改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||
expectedModel: "claude-sonnet-4-5-20241022",
|
||||
endpoint: "count_tokens",
|
||||
},
|
||||
{
|
||||
name: "CountTokens: 通配符映射应改写模型",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
|
||||
expectedModel: "claude-sonnet-4-5-20241022",
|
||||
endpoint: "count_tokens",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
body := []byte(`{"model":"` + tt.model + `","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
Model: tt.model,
|
||||
}
|
||||
|
||||
credentials := map[string]any{
|
||||
"api_key": "upstream-key",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
}
|
||||
if tt.modelMapping != nil {
|
||||
credentials["model_mapping"] = tt.modelMapping
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 300,
|
||||
Name: "edge-case-test",
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: credentials,
|
||||
Extra: map[string]any{"anthropic_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
if tt.endpoint == "messages" {
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
parsed.Stream = false
|
||||
|
||||
upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":5,"output_tokens":3}}`
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamJSON)),
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
|
||||
"Forward 上游请求体中的模型应为: %s", tt.expectedModel)
|
||||
} else {
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||
|
||||
upstreamRespBody := `{"input_tokens":42}`
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||
},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
|
||||
"CountTokens 上游请求体中的模型应为: %s", tt.expectedModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields
|
||||
// 确保模型映射只替换 model 字段,不影响请求体中的其他字段
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||
|
||||
// 包含复杂字段的请求体:system、thinking、messages
|
||||
body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"You are a helpful assistant."}],"messages":[{"role":"user","content":[{"type":"text","text":"hello world"}]}],"thinking":{"type":"enabled","budget_tokens":5000},"max_tokens":1024}`)
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
}
|
||||
|
||||
upstreamRespBody := `{"input_tokens":42}`
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 301,
|
||||
Name: "preserve-fields-test",
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "upstream-key",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"model_mapping": map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||
},
|
||||
Extra: map[string]any{"anthropic_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
sentBody := upstream.lastBody
|
||||
require.Equal(t, "claude-sonnet-4-5-20241022", gjson.GetBytes(sentBody, "model").String(), "model 应被映射")
|
||||
require.Equal(t, "You are a helpful assistant.", gjson.GetBytes(sentBody, "system.0.text").String(), "system 字段不应被修改")
|
||||
require.Equal(t, "hello world", gjson.GetBytes(sentBody, "messages.0.content.0.text").String(), "messages 字段不应被修改")
|
||||
require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String(), "thinking 字段不应被修改")
|
||||
require.Equal(t, int64(5000), gjson.GetBytes(sentBody, "thinking.budget_tokens").Int(), "thinking.budget_tokens 不应被修改")
|
||||
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改")
|
||||
}
|
||||
|
||||
// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping
|
||||
// 确保空模型名不会触发映射逻辑
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
Model: "", // 空模型
|
||||
}
|
||||
|
||||
upstreamRespBody := `{"input_tokens":10}`
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 302,
|
||||
Name: "empty-model-test",
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "upstream-key",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"model_mapping": map[string]any{"*": "claude-3-opus-20240229"},
|
||||
},
|
||||
Extra: map[string]any{"anthropic_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
// 空模型名时,body 应原样透传,不应触发映射
|
||||
require.Equal(t, body, upstream.lastBody, "空模型名时请求体不应被修改")
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
363
backend/internal/service/gateway_group_isolation_test.go
Normal file
363
backend/internal/service/gateway_group_isolation_test.go
Normal file
@@ -0,0 +1,363 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// Part 1: isAccountInGroup 单元测试
|
||||
// ============================================================================
|
||||
|
||||
func TestIsAccountInGroup(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
groupID100 := int64(100)
|
||||
groupID200 := int64(200)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
groupID *int64
|
||||
expected bool
|
||||
}{
|
||||
// groupID == nil(无分组 API Key)
|
||||
{
|
||||
"nil_groupID_ungrouped_account_nil_groups",
|
||||
&Account{ID: 1, AccountGroups: nil},
|
||||
nil, true,
|
||||
},
|
||||
{
|
||||
"nil_groupID_ungrouped_account_empty_slice",
|
||||
&Account{ID: 2, AccountGroups: []AccountGroup{}},
|
||||
nil, true,
|
||||
},
|
||||
{
|
||||
"nil_groupID_grouped_account_single",
|
||||
&Account{ID: 3, AccountGroups: []AccountGroup{{GroupID: 100}}},
|
||||
nil, false,
|
||||
},
|
||||
{
|
||||
"nil_groupID_grouped_account_multiple",
|
||||
&Account{ID: 4, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}},
|
||||
nil, false,
|
||||
},
|
||||
// groupID != nil(有分组 API Key)
|
||||
{
|
||||
"with_groupID_account_in_group",
|
||||
&Account{ID: 5, AccountGroups: []AccountGroup{{GroupID: 100}}},
|
||||
&groupID100, true,
|
||||
},
|
||||
{
|
||||
"with_groupID_account_not_in_group",
|
||||
&Account{ID: 6, AccountGroups: []AccountGroup{{GroupID: 200}}},
|
||||
&groupID100, false,
|
||||
},
|
||||
{
|
||||
"with_groupID_ungrouped_account",
|
||||
&Account{ID: 7, AccountGroups: nil},
|
||||
&groupID100, false,
|
||||
},
|
||||
{
|
||||
"with_groupID_multi_group_account_match_one",
|
||||
&Account{ID: 8, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}},
|
||||
&groupID200, true,
|
||||
},
|
||||
{
|
||||
"with_groupID_multi_group_account_no_match",
|
||||
&Account{ID: 9, AccountGroups: []AccountGroup{{GroupID: 300}, {GroupID: 400}}},
|
||||
&groupID100, false,
|
||||
},
|
||||
// 防御性边界
|
||||
{
|
||||
"nil_account_nil_groupID",
|
||||
nil,
|
||||
nil, false,
|
||||
},
|
||||
{
|
||||
"nil_account_with_groupID",
|
||||
nil,
|
||||
&groupID100, false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.isAccountInGroup(tt.account, tt.groupID)
|
||||
require.Equal(t, tt.expected, got, "isAccountInGroup 结果不符预期")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Part 2: 分组隔离端到端调度测试
|
||||
// ============================================================================
|
||||
|
||||
// groupAwareMockAccountRepo 嵌入 mockAccountRepoForPlatform,覆写分组隔离相关方法。
|
||||
// allAccounts 存储所有账号,分组查询方法按 AccountGroups 字段进行真实过滤。
|
||||
type groupAwareMockAccountRepo struct {
|
||||
*mockAccountRepoForPlatform
|
||||
allAccounts []Account
|
||||
}
|
||||
|
||||
// ListSchedulableUngroupedByPlatform 仅返回未分组账号(AccountGroups 为空)
|
||||
func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
var result []Account
|
||||
for _, acc := range m.allAccounts {
|
||||
if acc.Platform == platform && acc.IsSchedulable() && len(acc.AccountGroups) == 0 {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListSchedulableUngroupedByPlatforms 仅返回未分组账号(多平台版本)
|
||||
func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
platformSet := make(map[string]bool, len(platforms))
|
||||
for _, p := range platforms {
|
||||
platformSet[p] = true
|
||||
}
|
||||
var result []Account
|
||||
for _, acc := range m.allAccounts {
|
||||
if platformSet[acc.Platform] && acc.IsSchedulable() && len(acc.AccountGroups) == 0 {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListSchedulableByGroupIDAndPlatform 返回属于指定分组的账号
|
||||
func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
var result []Account
|
||||
for _, acc := range m.allAccounts {
|
||||
if acc.Platform == platform && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListSchedulableByGroupIDAndPlatforms 返回属于指定分组的账号(多平台版本)
|
||||
func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
platformSet := make(map[string]bool, len(platforms))
|
||||
for _, p := range platforms {
|
||||
platformSet[p] = true
|
||||
}
|
||||
var result []Account
|
||||
for _, acc := range m.allAccounts {
|
||||
if platformSet[acc.Platform] && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// accountBelongsToGroup 检查账号是否属于指定分组
|
||||
func accountBelongsToGroup(acc Account, groupID int64) bool {
|
||||
for _, ag := range acc.AccountGroups {
|
||||
if ag.GroupID == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify interface implementation
|
||||
var _ AccountRepository = (*groupAwareMockAccountRepo)(nil)
|
||||
|
||||
// newGroupAwareMockRepo 创建分组感知的 mock repo
|
||||
func newGroupAwareMockRepo(accounts []Account) *groupAwareMockAccountRepo {
|
||||
byID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
byID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
return &groupAwareMockAccountRepo{
|
||||
mockAccountRepoForPlatform: &mockAccountRepoForPlatform{
|
||||
accounts: accounts,
|
||||
accountsByID: byID,
|
||||
},
|
||||
allAccounts: accounts,
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupIsolation_UngroupedKey_ShouldNotScheduleGroupedAccounts(t *testing.T) {
|
||||
// 场景:无分组 API Key(groupID=nil),池中只有已分组账号 → 应返回错误
|
||||
ctx := context.Background()
|
||||
|
||||
accounts := []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 100}}},
|
||||
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 200}}},
|
||||
}
|
||||
repo := newGroupAwareMockRepo(accounts)
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
|
||||
require.Error(t, err, "无分组 Key 不应调度到已分组账号")
|
||||
require.Nil(t, acc)
|
||||
}
|
||||
|
||||
func TestGroupIsolation_GroupedKey_ShouldNotScheduleUngroupedAccounts(t *testing.T) {
|
||||
// 场景:有分组 API Key(groupID=100),池中只有未分组账号 → 应返回错误
|
||||
ctx := context.Background()
|
||||
groupID := int64(100)
|
||||
|
||||
accounts := []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: nil},
|
||||
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{}},
|
||||
}
|
||||
repo := newGroupAwareMockRepo(accounts)
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI)
|
||||
require.Error(t, err, "有分组 Key 不应调度到未分组账号")
|
||||
require.Nil(t, acc)
|
||||
}
|
||||
|
||||
func TestGroupIsolation_UngroupedKey_ShouldOnlyScheduleUngroupedAccounts(t *testing.T) {
|
||||
// 场景:无分组 API Key(groupID=nil),池中有未分组和已分组账号 → 应只选中未分组的
|
||||
ctx := context.Background()
|
||||
|
||||
accounts := []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组,不应被选中
|
||||
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: nil}, // 未分组,应被选中
|
||||
{ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 200}}}, // 已分组,不应被选中
|
||||
}
|
||||
repo := newGroupAwareMockRepo(accounts)
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
|
||||
require.NoError(t, err, "应成功调度未分组账号")
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "应选中未分组的账号 ID=2")
|
||||
}
|
||||
|
||||
func TestGroupIsolation_GroupedKey_ShouldOnlyScheduleMatchingGroupAccounts(t *testing.T) {
|
||||
// 场景:有分组 API Key(groupID=100),池中有未分组和多个分组账号 → 应只选中分组 100 内的
|
||||
ctx := context.Background()
|
||||
groupID := int64(100)
|
||||
|
||||
accounts := []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: nil}, // 未分组,不应被选中
|
||||
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 200}}}, // 属于分组 200,不应被选中
|
||||
{ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 属于分组 100,应被选中
|
||||
}
|
||||
repo := newGroupAwareMockRepo(accounts)
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI)
|
||||
require.NoError(t, err, "应成功调度分组内账号")
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(3), acc.ID, "应选中分组 100 内的账号 ID=3")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Part 3: SimpleMode 旁路测试
|
||||
// ============================================================================
|
||||
|
||||
func TestGroupIsolation_SimpleMode_SkipsGroupIsolation(t *testing.T) {
|
||||
// SimpleMode 应跳过分组隔离,使用 ListSchedulableByPlatform 返回所有账号。
|
||||
// 测试非 useMixed 路径(platform=openai,不会触发 mixed 调度逻辑)。
|
||||
ctx := context.Background()
|
||||
|
||||
// 混合未分组和已分组账号,SimpleMode 下应全部可调度
|
||||
accounts := []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组
|
||||
{ID: 2, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: nil}, // 未分组
|
||||
}
|
||||
|
||||
// 使用基础 mock(ListSchedulableByPlatform 返回所有匹配平台的账号,不做分组过滤)
|
||||
byID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
byID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: accounts,
|
||||
accountsByID: byID,
|
||||
}
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: &config.Config{RunMode: config.RunModeSimple},
|
||||
}
|
||||
|
||||
// groupID=nil 时,SimpleMode 应使用 ListSchedulableByPlatform(不过滤分组)
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
|
||||
require.NoError(t, err, "SimpleMode 应跳过分组隔离直接返回账号")
|
||||
require.NotNil(t, acc)
|
||||
// 应选择优先级最高的账号(Priority=1, ID=2),即使它未分组
|
||||
require.Equal(t, int64(2), acc.ID, "SimpleMode 应按优先级选择,不考虑分组")
|
||||
}
|
||||
|
||||
func TestGroupIsolation_SimpleMode_GroupedAccountAlsoSchedulable(t *testing.T) {
|
||||
// SimpleMode + groupID=nil 时,已分组账号也应该可被调度
|
||||
ctx := context.Background()
|
||||
|
||||
// 只有已分组账号,在 standard 模式下 groupID=nil 会报错,但 simple 模式应正常
|
||||
accounts := []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 100}}},
|
||||
}
|
||||
|
||||
byID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
byID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: accounts,
|
||||
accountsByID: byID,
|
||||
}
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: &config.Config{RunMode: config.RunModeSimple},
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
|
||||
require.NoError(t, err, "SimpleMode 下已分组账号也应可调度")
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "SimpleMode 应能调度已分组账号")
|
||||
}
|
||||
@@ -147,6 +147,12 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByPlatforms(ctx context.Cont
|
||||
func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1782,8 +1782,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
|
||||
} else {
|
||||
} else if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms)
|
||||
}
|
||||
if err != nil {
|
||||
slog.Debug("account_scheduling_list_failed",
|
||||
@@ -1824,7 +1826,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
||||
// 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, platform)
|
||||
}
|
||||
if err != nil {
|
||||
slog.Debug("account_scheduling_list_failed",
|
||||
@@ -1964,14 +1966,15 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
|
||||
}
|
||||
|
||||
// isAccountInGroup checks if the account belongs to the specified group.
|
||||
// Returns true if groupID is nil (no group restriction) or account belongs to the group.
|
||||
// When groupID is nil, returns true only for ungrouped accounts (no group assignments).
|
||||
func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool {
|
||||
if groupID == nil {
|
||||
return true // 无分组限制
|
||||
}
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if groupID == nil {
|
||||
// 无分组的 API Key 只能使用未分组的账号
|
||||
return len(account.AccountGroups) == 0
|
||||
}
|
||||
for _, ag := range account.AccountGroups {
|
||||
if ag.GroupID == *groupID {
|
||||
return true
|
||||
@@ -3886,7 +3889,16 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime)
|
||||
passthroughBody := parsed.Body
|
||||
passthroughModel := parsed.Model
|
||||
if passthroughModel != "" {
|
||||
if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel {
|
||||
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
|
||||
logger.LegacyPrintf("service.gateway", "Passthrough model mapping: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
|
||||
passthroughModel = mappedModel
|
||||
}
|
||||
}
|
||||
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
|
||||
}
|
||||
|
||||
body := parsed.Body
|
||||
@@ -4571,7 +4583,7 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages"
|
||||
targetURL = validatedURL + "/v1/messages?beta=true"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||
@@ -4951,7 +4963,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages"
|
||||
targetURL = validatedURL + "/v1/messages?beta=true"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6361,9 +6373,10 @@ type RecordUsageInput struct {
|
||||
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
||||
}
|
||||
|
||||
// APIKeyQuotaUpdater defines the interface for updating API Key quota
|
||||
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
|
||||
type APIKeyQuotaUpdater interface {
|
||||
UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error
|
||||
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
@@ -6557,6 +6570,14 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
}
|
||||
}
|
||||
|
||||
// Update API Key rate limit usage
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
|
||||
}
|
||||
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// Schedule batch update for account last_used_at
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
|
||||
@@ -6746,6 +6767,14 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
}
|
||||
}
|
||||
|
||||
// Update API Key rate limit usage
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
|
||||
}
|
||||
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// Schedule batch update for account last_used_at
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
|
||||
@@ -6761,7 +6790,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body)
|
||||
passthroughBody := parsed.Body
|
||||
if reqModel := parsed.Model; reqModel != "" {
|
||||
if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel {
|
||||
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
|
||||
logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
|
||||
}
|
||||
}
|
||||
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody)
|
||||
}
|
||||
|
||||
body := parsed.Body
|
||||
@@ -7052,7 +7088,7 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens"
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||
@@ -7099,7 +7135,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens"
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -122,6 +122,28 @@ func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) {
|
||||
body: []byte(`overloaded service`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "gemini_apikey_temp_unschedulable_401_second_hit_returns_none",
|
||||
account: &Account{
|
||||
ID: 105,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "gemini_custom_codes_override_temp_unschedulable",
|
||||
account: &Account{
|
||||
|
||||
@@ -431,7 +431,10 @@ func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Co
|
||||
if groupID != nil {
|
||||
return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
||||
}
|
||||
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||
}
|
||||
return s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, queryPlatforms)
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
|
||||
@@ -138,6 +138,12 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
|
||||
}
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user