mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-22 15:44:46 +08:00
feat(sync): full code sync from release
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -116,13 +116,12 @@ backend/.installed
|
|||||||
# ===================
|
# ===================
|
||||||
tests
|
tests
|
||||||
CLAUDE.md
|
CLAUDE.md
|
||||||
AGENTS.md
|
|
||||||
.claude
|
.claude
|
||||||
scripts
|
scripts
|
||||||
.code-review-state
|
.code-review-state
|
||||||
openspec/
|
#openspec/
|
||||||
code-reviews/
|
code-reviews/
|
||||||
AGENTS.md
|
#AGENTS.md
|
||||||
backend/cmd/server/server
|
backend/cmd/server/server
|
||||||
deploy/docker-compose.override.yml
|
deploy/docker-compose.override.yml
|
||||||
.gocache/
|
.gocache/
|
||||||
@@ -132,4 +131,5 @@ docs/*
|
|||||||
.codex/
|
.codex/
|
||||||
frontend/coverage/
|
frontend/coverage/
|
||||||
aicodex
|
aicodex
|
||||||
|
output/
|
||||||
|
|
||||||
|
|||||||
105
AGENTS.md
Normal file
105
AGENTS.md
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
# Repository Guidelines
|
||||||
|
|
||||||
|
## Project Structure & Module Organization
|
||||||
|
- `backend/`: Go service. `cmd/server` is the entrypoint, `internal/` contains handlers/services/repositories/server wiring, `ent/` holds Ent schemas and generated ORM code, `migrations/` stores DB migrations, and `internal/web/dist/` is the embedded frontend build output.
|
||||||
|
- `frontend/`: Vue 3 + TypeScript app. Main folders are `src/api`, `src/components`, `src/views`, `src/stores`, `src/composables`, `src/utils`, and test files in `src/**/__tests__`.
|
||||||
|
- `deploy/`: Docker and deployment assets (`docker-compose*.yml`, `.env.example`, `config.example.yaml`).
|
||||||
|
- `openspec/`: Spec-driven change docs (`changes/<id>/{proposal,design,tasks}.md`).
|
||||||
|
- `tools/`: Utility scripts (security/perf checks).
|
||||||
|
|
||||||
|
## Build, Test, and Development Commands
|
||||||
|
```bash
|
||||||
|
make build # Build backend + frontend
|
||||||
|
make test # Backend tests + frontend lint/typecheck
|
||||||
|
cd backend && make build # Build backend binary
|
||||||
|
cd backend && make test-unit # Go unit tests
|
||||||
|
cd backend && make test-integration # Go integration tests
|
||||||
|
cd backend && make test # go test ./... + golangci-lint
|
||||||
|
cd frontend && pnpm install --frozen-lockfile
|
||||||
|
cd frontend && pnpm dev # Vite dev server
|
||||||
|
cd frontend && pnpm build # Type-check + production build
|
||||||
|
cd frontend && pnpm test:run # Vitest run
|
||||||
|
cd frontend && pnpm test:coverage # Vitest + coverage report
|
||||||
|
python3 tools/secret_scan.py # Secret scan
|
||||||
|
```
|
||||||
|
|
||||||
|
## Coding Style & Naming Conventions
|
||||||
|
- Go: format with `gofmt`; lint with `golangci-lint` (`backend/.golangci.yml`).
|
||||||
|
- Respect layering: `internal/service` and `internal/handler` must not import `internal/repository`, `gorm`, or `redis` directly (enforced by depguard).
|
||||||
|
- Frontend: Vue SFC + TypeScript, 2-space indentation, ESLint rules from `frontend/.eslintrc.cjs`.
|
||||||
|
- Naming: components use `PascalCase.vue`, composables use `useXxx.ts`, Go tests use `*_test.go`, frontend tests use `*.spec.ts`.
|
||||||
|
|
||||||
|
## Go & Frontend Development Standards
|
||||||
|
- Control branch complexity: `if` nesting must not exceed 3 levels. Refactor with guard clauses, early returns, helper functions, or strategy maps when deeper logic appears.
|
||||||
|
- JSON hot-path rule: for read-only/partial-field extraction, prefer `gjson` over full `encoding/json` struct unmarshal to reduce allocations and improve latency.
|
||||||
|
- Exception rule: if full schema validation or typed writes are required, `encoding/json` is allowed, but PR must explain why `gjson` is not suitable.
|
||||||
|
|
||||||
|
### Go Performance Rules
|
||||||
|
- Optimization workflow rule: benchmark/profile first, then optimize. Use `go test -bench`, `go tool pprof`, and runtime diagnostics before changing hot-path code.
|
||||||
|
- For hot functions, run escape analysis (`go build -gcflags=all='-m -m'`) and prioritize stack allocation where reasonable.
|
||||||
|
- Every external I/O path must use `context.Context` with explicit timeout/cancel.
|
||||||
|
- When creating derived contexts (`WithTimeout` / `WithDeadline`), always `defer cancel()` to release resources.
|
||||||
|
- Preallocate slices/maps when size can be estimated (`make([]T, 0, n)`, `make(map[K]V, n)`).
|
||||||
|
- Avoid unnecessary allocations in loops; reuse buffers and prefer `strings.Builder`/`bytes.Buffer`.
|
||||||
|
- Prohibit N+1 query patterns; batch DB/Redis operations and verify indexes for new query paths.
|
||||||
|
- For hot-path changes, include benchmark or latency comparison evidence (e.g., `go test -bench` before/after).
|
||||||
|
- Keep goroutine growth bounded (worker pool/semaphore), and avoid unbounded fan-out.
|
||||||
|
- Lock minimization rule: if a lock can be avoided, do not use a lock. Prefer ownership transfer (channel), sharding, immutable snapshots, copy-on-write, or atomic operations to reduce contention.
|
||||||
|
- When locks are unavoidable, keep critical sections minimal, avoid nested locks, and document why lock-free alternatives are not feasible.
|
||||||
|
- Follow `sync` guidance: prefer channels for higher-level synchronization; use low-level mutex primitives only where necessary.
|
||||||
|
- Avoid reflection and `interface{}`-heavy conversions in hot paths; use typed structs/functions.
|
||||||
|
- Use `sync.Pool` only when benchmark proves allocation reduction; remove if no measurable gain.
|
||||||
|
- Avoid repeated `time.Now()`/`fmt.Sprintf` in tight loops; hoist or cache when possible.
|
||||||
|
- For stable high-traffic binaries, maintain representative `default.pgo` profiles and keep `go build -pgo=auto` enabled.
|
||||||
|
|
||||||
|
### Data Access & Cache Rules
|
||||||
|
- Every new/changed SQL query must be checked with `EXPLAIN` (or `EXPLAIN ANALYZE` in staging) and include index rationale in PR.
|
||||||
|
- Default to keyset pagination for large tables; avoid deep `OFFSET` scans on hot endpoints.
|
||||||
|
- Query only required columns; prohibit broad `SELECT *` in latency-sensitive paths.
|
||||||
|
- Keep transactions short; never perform external RPC/network calls inside DB transactions.
|
||||||
|
- Connection pool must be explicitly tuned and observed via `DB.Stats` (`SetMaxOpenConns`, `SetMaxIdleConns`, `SetConnMaxIdleTime`, `SetConnMaxLifetime`).
|
||||||
|
- Avoid overly small `MaxOpenConns` that can turn DB access into lock/semaphore bottlenecks.
|
||||||
|
- Cache keys must be versioned (e.g., `user_usage:v2:{id}`) and TTL should include jitter to avoid thundering herd.
|
||||||
|
- Use request coalescing (`singleflight` or equivalent) for high-concurrency cache miss paths.
|
||||||
|
|
||||||
|
### Frontend Performance Rules
|
||||||
|
- Route-level and heavy-module code splitting is required; lazy-load non-critical views/components.
|
||||||
|
- API requests must support cancellation and deduplication; use debounce/throttle for search-like inputs.
|
||||||
|
- Minimize unnecessary reactivity: avoid deep watch chains when computed/cache can solve it.
|
||||||
|
- Prefer stable props and selective rendering controls (`v-once`, `v-memo`) for expensive subtrees when data is static or keyed.
|
||||||
|
- Large data rendering must use pagination or virtualization (especially tables/lists >200 rows).
|
||||||
|
- Move expensive CPU work off the main thread (Web Worker) or chunk tasks to avoid UI blocking.
|
||||||
|
- Keep bundle growth controlled; avoid adding heavy dependencies without clear ROI and alternatives review.
|
||||||
|
- Avoid expensive inline computations in templates; move to cached `computed` selectors.
|
||||||
|
- Keep state normalized; avoid duplicated derived state across multiple stores/components.
|
||||||
|
- Load charts/editors/export libraries on demand only (`dynamic import`) instead of app-entry import.
|
||||||
|
- Core Web Vitals targets (p75): `LCP <= 2.5s`, `INP <= 200ms`, `CLS <= 0.1`.
|
||||||
|
- Main-thread task budget: keep individual tasks below ~50ms; split long tasks and yield between chunks.
|
||||||
|
- Enforce frontend budgets in CI (Lighthouse CI with `budget.json`) for critical routes.
|
||||||
|
|
||||||
|
### Performance Budget & PR Evidence
|
||||||
|
- Performance budget is mandatory for hot-path PRs: backend p95/p99 latency and CPU/memory must not regress by more than 5% versus baseline.
|
||||||
|
- Frontend budget: new route-level JS should not increase by more than 30KB gzip without explicit approval.
|
||||||
|
- For any gateway/protocol hot path, attach a reproducible benchmark command and results (input size, concurrency, before/after table).
|
||||||
|
- Profiling evidence is required for major optimizations (`pprof`, flamegraph, browser performance trace, or bundle analyzer output).
|
||||||
|
|
||||||
|
### Quality Gate
|
||||||
|
- Any changed code must include new or updated unit tests.
|
||||||
|
- Coverage must stay above 85% (global frontend threshold and no regressions for touched backend modules).
|
||||||
|
- If any rule is intentionally violated, document reason, risk, and mitigation in the PR description.
|
||||||
|
|
||||||
|
## Testing Guidelines
|
||||||
|
- Backend suites: `go test -tags=unit ./...`, `go test -tags=integration ./...`, and e2e where relevant.
|
||||||
|
- Frontend uses Vitest (`jsdom`); keep tests near modules (`__tests__`) or as `*.spec.ts`.
|
||||||
|
- Enforce unit-test and coverage rules defined in `Quality Gate`.
|
||||||
|
- Before opening a PR, run `make test` plus targeted tests for touched areas.
|
||||||
|
|
||||||
|
## Commit & Pull Request Guidelines
|
||||||
|
- Follow Conventional Commits: `feat(scope): ...`, `fix(scope): ...`, `chore(scope): ...`, `docs(scope): ...`.
|
||||||
|
- PRs should include a clear summary, linked issue/spec, commands run for verification, and screenshots/GIFs for UI changes.
|
||||||
|
- For behavior/API changes, add or update `openspec/changes/...` artifacts.
|
||||||
|
- If dependencies change, commit `frontend/pnpm-lock.yaml` in the same PR.
|
||||||
|
|
||||||
|
## Security & Configuration Tips
|
||||||
|
- Use `deploy/.env.example` and `deploy/config.example.yaml` as templates; do not commit real credentials.
|
||||||
|
- Set stable `JWT_SECRET`, `TOTP_ENCRYPTION_KEY`, and strong database passwords outside local dev.
|
||||||
13
Dockerfile
13
Dockerfile
@@ -8,7 +8,7 @@
|
|||||||
|
|
||||||
ARG NODE_IMAGE=node:24-alpine
|
ARG NODE_IMAGE=node:24-alpine
|
||||||
ARG GOLANG_IMAGE=golang:1.25.7-alpine
|
ARG GOLANG_IMAGE=golang:1.25.7-alpine
|
||||||
ARG ALPINE_IMAGE=alpine:3.20
|
ARG ALPINE_IMAGE=alpine:3.21
|
||||||
ARG GOPROXY=https://goproxy.cn,direct
|
ARG GOPROXY=https://goproxy.cn,direct
|
||||||
ARG GOSUMDB=sum.golang.google.cn
|
ARG GOSUMDB=sum.golang.google.cn
|
||||||
|
|
||||||
@@ -68,6 +68,7 @@ RUN VERSION_VALUE="${VERSION}" && \
|
|||||||
CGO_ENABLED=0 GOOS=linux go build \
|
CGO_ENABLED=0 GOOS=linux go build \
|
||||||
-tags embed \
|
-tags embed \
|
||||||
-ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \
|
-ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \
|
||||||
|
-trimpath \
|
||||||
-o /app/sub2api \
|
-o /app/sub2api \
|
||||||
./cmd/server
|
./cmd/server
|
||||||
|
|
||||||
@@ -85,7 +86,6 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
|||||||
RUN apk add --no-cache \
|
RUN apk add --no-cache \
|
||||||
ca-certificates \
|
ca-certificates \
|
||||||
tzdata \
|
tzdata \
|
||||||
curl \
|
|
||||||
&& rm -rf /var/cache/apk/*
|
&& rm -rf /var/cache/apk/*
|
||||||
|
|
||||||
# Create non-root user
|
# Create non-root user
|
||||||
@@ -95,11 +95,12 @@ RUN addgroup -g 1000 sub2api && \
|
|||||||
# Set working directory
|
# Set working directory
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Copy binary from builder
|
# Copy binary/resources with ownership to avoid extra full-layer chown copy
|
||||||
COPY --from=backend-builder /app/sub2api /app/sub2api
|
COPY --from=backend-builder --chown=sub2api:sub2api /app/sub2api /app/sub2api
|
||||||
|
COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/resources
|
||||||
|
|
||||||
# Create data directory
|
# Create data directory
|
||||||
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
||||||
|
|
||||||
# Switch to non-root user
|
# Switch to non-root user
|
||||||
USER sub2api
|
USER sub2api
|
||||||
@@ -109,7 +110,7 @@ EXPOSE 8080
|
|||||||
|
|
||||||
# Health check
|
# Health check
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||||
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||||
|
|
||||||
# Run the application
|
# Run the application
|
||||||
ENTRYPOINT ["/app/sub2api"]
|
ENTRYPOINT ["/app/sub2api"]
|
||||||
|
|||||||
9
Makefile
9
Makefile
@@ -1,4 +1,4 @@
|
|||||||
.PHONY: build build-backend build-frontend test test-backend test-frontend secret-scan
|
.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-datamanagementd secret-scan
|
||||||
|
|
||||||
# 一键编译前后端
|
# 一键编译前后端
|
||||||
build: build-backend build-frontend
|
build: build-backend build-frontend
|
||||||
@@ -11,6 +11,10 @@ build-backend:
|
|||||||
build-frontend:
|
build-frontend:
|
||||||
@pnpm --dir frontend run build
|
@pnpm --dir frontend run build
|
||||||
|
|
||||||
|
# 编译 datamanagementd(宿主机数据管理进程)
|
||||||
|
build-datamanagementd:
|
||||||
|
@cd datamanagement && go build -o datamanagementd ./cmd/datamanagementd
|
||||||
|
|
||||||
# 运行测试(后端 + 前端)
|
# 运行测试(后端 + 前端)
|
||||||
test: test-backend test-frontend
|
test: test-backend test-frontend
|
||||||
|
|
||||||
@@ -21,5 +25,8 @@ test-frontend:
|
|||||||
@pnpm --dir frontend run lint:check
|
@pnpm --dir frontend run lint:check
|
||||||
@pnpm --dir frontend run typecheck
|
@pnpm --dir frontend run typecheck
|
||||||
|
|
||||||
|
test-datamanagementd:
|
||||||
|
@cd datamanagement && go test ./...
|
||||||
|
|
||||||
secret-scan:
|
secret-scan:
|
||||||
@python3 tools/secret_scan.py
|
@python3 tools/secret_scan.py
|
||||||
|
|||||||
28
README.md
28
README.md
@@ -57,6 +57,34 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## Codex CLI WebSocket v2 Example
|
||||||
|
|
||||||
|
To enable OpenAI WebSocket Mode v2 in Codex CLI with Sub2API, add the following to `~/.codex/config.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
model_provider = "aicodx2api"
|
||||||
|
model = "gpt-5.3-codex"
|
||||||
|
review_model = "gpt-5.3-codex"
|
||||||
|
model_reasoning_effort = "xhigh"
|
||||||
|
disable_response_storage = true
|
||||||
|
network_access = "enabled"
|
||||||
|
windows_wsl_setup_acknowledged = true
|
||||||
|
|
||||||
|
[model_providers.aicodx2api]
|
||||||
|
name = "aicodx2api"
|
||||||
|
base_url = "https://api.sub2api.ai"
|
||||||
|
wire_api = "responses"
|
||||||
|
supports_websockets = true
|
||||||
|
requires_openai_auth = true
|
||||||
|
|
||||||
|
[features]
|
||||||
|
responses_websockets_v2 = true
|
||||||
|
```
|
||||||
|
|
||||||
|
After updating the config, restart Codex CLI.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Deployment
|
## Deployment
|
||||||
|
|
||||||
### Method 1: Script Installation (Recommended)
|
### Method 1: Script Installation (Recommended)
|
||||||
|
|||||||
38
README_CN.md
38
README_CN.md
@@ -62,6 +62,32 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
|
||||||
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
|
||||||
|
|
||||||
|
## Codex CLI 开启 OpenAI WebSocket Mode v2 示例配置
|
||||||
|
|
||||||
|
如需在 Codex CLI 中通过 Sub2API 启用 OpenAI WebSocket Mode v2,可将以下配置写入 `~/.codex/config.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
model_provider = "aicodx2api"
|
||||||
|
model = "gpt-5.3-codex"
|
||||||
|
review_model = "gpt-5.3-codex"
|
||||||
|
model_reasoning_effort = "xhigh"
|
||||||
|
disable_response_storage = true
|
||||||
|
network_access = "enabled"
|
||||||
|
windows_wsl_setup_acknowledged = true
|
||||||
|
|
||||||
|
[model_providers.aicodx2api]
|
||||||
|
name = "aicodx2api"
|
||||||
|
base_url = "https://api.sub2api.ai"
|
||||||
|
wire_api = "responses"
|
||||||
|
supports_websockets = true
|
||||||
|
requires_openai_auth = true
|
||||||
|
|
||||||
|
[features]
|
||||||
|
responses_websockets_v2 = true
|
||||||
|
```
|
||||||
|
|
||||||
|
配置更新后,重启 Codex CLI 使其生效。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 部署方式
|
## 部署方式
|
||||||
@@ -246,6 +272,18 @@ docker-compose -f docker-compose.local.yml logs -f sub2api
|
|||||||
|
|
||||||
**推荐:** 使用 `docker-compose.local.yml`(脚本部署)以便更轻松地管理数据。
|
**推荐:** 使用 `docker-compose.local.yml`(脚本部署)以便更轻松地管理数据。
|
||||||
|
|
||||||
|
#### 启用“数据管理”功能(datamanagementd)
|
||||||
|
|
||||||
|
如需启用管理后台“数据管理”,需要额外部署宿主机数据管理进程 `datamanagementd`。
|
||||||
|
|
||||||
|
关键点:
|
||||||
|
|
||||||
|
- 主进程固定探测:`/tmp/sub2api-datamanagement.sock`
|
||||||
|
- 只有该 Socket 可连通时,数据管理功能才会开启
|
||||||
|
- Docker 场景需将宿主机 Socket 挂载到容器同路径
|
||||||
|
|
||||||
|
详细部署步骤见:`deploy/DATAMANAGEMENTD_CN.md`
|
||||||
|
|
||||||
#### 访问
|
#### 访问
|
||||||
|
|
||||||
在浏览器中打开 `http://你的服务器IP:8080`
|
在浏览器中打开 `http://你的服务器IP:8080`
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"global": {
|
"global": {
|
||||||
"exclude": "G704"
|
"exclude": "G704,G101,G103,G104,G109,G115,G201,G202,G301,G302,G304,G306,G404"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,14 @@
|
|||||||
.PHONY: build test test-unit test-integration test-e2e
|
.PHONY: build generate test test-unit test-integration test-e2e
|
||||||
|
|
||||||
|
VERSION ?= $(shell tr -d '\r\n' < ./cmd/server/VERSION)
|
||||||
|
LDFLAGS ?= -s -w -X main.Version=$(VERSION)
|
||||||
|
|
||||||
build:
|
build:
|
||||||
go build -o bin/server ./cmd/server
|
CGO_ENABLED=0 go build -ldflags="$(LDFLAGS)" -trimpath -o bin/server ./cmd/server
|
||||||
|
|
||||||
|
generate:
|
||||||
|
go generate ./ent
|
||||||
|
go generate ./cmd/server
|
||||||
|
|
||||||
test:
|
test:
|
||||||
go test ./...
|
go test ./...
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
0.1.85
|
0.1.85.15
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/ent"
|
"github.com/Wei-Shaw/sub2api/ent"
|
||||||
@@ -84,16 +85,19 @@ func provideCleanup(
|
|||||||
openaiOAuth *service.OpenAIOAuthService,
|
openaiOAuth *service.OpenAIOAuthService,
|
||||||
geminiOAuth *service.GeminiOAuthService,
|
geminiOAuth *service.GeminiOAuthService,
|
||||||
antigravityOAuth *service.AntigravityOAuthService,
|
antigravityOAuth *service.AntigravityOAuthService,
|
||||||
|
openAIGateway *service.OpenAIGatewayService,
|
||||||
) func() {
|
) func() {
|
||||||
return func() {
|
return func() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Cleanup steps in reverse dependency order
|
type cleanupStep struct {
|
||||||
cleanupSteps := []struct {
|
|
||||||
name string
|
name string
|
||||||
fn func() error
|
fn func() error
|
||||||
}{
|
}
|
||||||
|
|
||||||
|
// 应用层清理步骤可并行执行,基础设施资源(Redis/Ent)最后按顺序关闭。
|
||||||
|
parallelSteps := []cleanupStep{
|
||||||
{"OpsScheduledReportService", func() error {
|
{"OpsScheduledReportService", func() error {
|
||||||
if opsScheduledReport != nil {
|
if opsScheduledReport != nil {
|
||||||
opsScheduledReport.Stop()
|
opsScheduledReport.Stop()
|
||||||
@@ -206,23 +210,60 @@ func provideCleanup(
|
|||||||
antigravityOAuth.Stop()
|
antigravityOAuth.Stop()
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"OpenAIWSPool", func() error {
|
||||||
|
if openAIGateway != nil {
|
||||||
|
openAIGateway.CloseOpenAIWSPool()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
infraSteps := []cleanupStep{
|
||||||
{"Redis", func() error {
|
{"Redis", func() error {
|
||||||
|
if rdb == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return rdb.Close()
|
return rdb.Close()
|
||||||
}},
|
}},
|
||||||
{"Ent", func() error {
|
{"Ent", func() error {
|
||||||
|
if entClient == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return entClient.Close()
|
return entClient.Close()
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, step := range cleanupSteps {
|
runParallel := func(steps []cleanupStep) {
|
||||||
if err := step.fn(); err != nil {
|
var wg sync.WaitGroup
|
||||||
log.Printf("[Cleanup] %s failed: %v", step.name, err)
|
for i := range steps {
|
||||||
// Continue with remaining cleanup steps even if one fails
|
step := steps[i]
|
||||||
} else {
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := step.fn(); err != nil {
|
||||||
|
log.Printf("[Cleanup] %s failed: %v", step.name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("[Cleanup] %s succeeded", step.name)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
runSequential := func(steps []cleanupStep) {
|
||||||
|
for i := range steps {
|
||||||
|
step := steps[i]
|
||||||
|
if err := step.fn(); err != nil {
|
||||||
|
log.Printf("[Cleanup] %s failed: %v", step.name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
log.Printf("[Cleanup] %s succeeded", step.name)
|
log.Printf("[Cleanup] %s succeeded", step.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
runParallel(parallelSteps)
|
||||||
|
runSequential(infraSteps)
|
||||||
|
|
||||||
// Check if context timed out
|
// Check if context timed out
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -139,6 +140,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
|
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
|
||||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||||
|
dataManagementService := service.NewDataManagementService()
|
||||||
|
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
||||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||||
@@ -163,7 +166,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
|
soraS3Storage := service.NewSoraS3Storage(settingService)
|
||||||
|
settingService.SetOnS3UpdateCallback(soraS3Storage.RefreshClient)
|
||||||
|
soraGenerationRepository := repository.NewSoraGenerationRepository(db)
|
||||||
|
soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService)
|
||||||
|
soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService)
|
||||||
|
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage)
|
||||||
opsHandler := admin.NewOpsHandler(opsService)
|
opsHandler := admin.NewOpsHandler(opsService)
|
||||||
updateCache := repository.NewUpdateCache(redisClient)
|
updateCache := repository.NewUpdateCache(redisClient)
|
||||||
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
||||||
@@ -184,19 +192,20 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
|
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
|
||||||
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
|
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
|
||||||
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
|
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
|
||||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
||||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
|
||||||
soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
|
soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
|
||||||
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
|
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
|
||||||
soraGatewayService := service.NewSoraGatewayService(soraSDKClient, soraMediaStorage, rateLimitService, configConfig)
|
soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig)
|
||||||
|
soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService)
|
||||||
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
|
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
|
||||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||||
totpHandler := handler.NewTotpHandler(totpService)
|
totpHandler := handler.NewTotpHandler(totpService)
|
||||||
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
|
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
|
||||||
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
|
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
|
||||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
|
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
|
||||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||||
@@ -211,7 +220,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
Cleanup: v,
|
Cleanup: v,
|
||||||
@@ -258,15 +267,18 @@ func provideCleanup(
|
|||||||
openaiOAuth *service.OpenAIOAuthService,
|
openaiOAuth *service.OpenAIOAuthService,
|
||||||
geminiOAuth *service.GeminiOAuthService,
|
geminiOAuth *service.GeminiOAuthService,
|
||||||
antigravityOAuth *service.AntigravityOAuthService,
|
antigravityOAuth *service.AntigravityOAuthService,
|
||||||
|
openAIGateway *service.OpenAIGatewayService,
|
||||||
) func() {
|
) func() {
|
||||||
return func() {
|
return func() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
cleanupSteps := []struct {
|
type cleanupStep struct {
|
||||||
name string
|
name string
|
||||||
fn func() error
|
fn func() error
|
||||||
}{
|
}
|
||||||
|
|
||||||
|
parallelSteps := []cleanupStep{
|
||||||
{"OpsScheduledReportService", func() error {
|
{"OpsScheduledReportService", func() error {
|
||||||
if opsScheduledReport != nil {
|
if opsScheduledReport != nil {
|
||||||
opsScheduledReport.Stop()
|
opsScheduledReport.Stop()
|
||||||
@@ -379,23 +391,60 @@ func provideCleanup(
|
|||||||
antigravityOAuth.Stop()
|
antigravityOAuth.Stop()
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"OpenAIWSPool", func() error {
|
||||||
|
if openAIGateway != nil {
|
||||||
|
openAIGateway.CloseOpenAIWSPool()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
infraSteps := []cleanupStep{
|
||||||
{"Redis", func() error {
|
{"Redis", func() error {
|
||||||
|
if rdb == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return rdb.Close()
|
return rdb.Close()
|
||||||
}},
|
}},
|
||||||
{"Ent", func() error {
|
{"Ent", func() error {
|
||||||
|
if entClient == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return entClient.Close()
|
return entClient.Close()
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, step := range cleanupSteps {
|
runParallel := func(steps []cleanupStep) {
|
||||||
if err := step.fn(); err != nil {
|
var wg sync.WaitGroup
|
||||||
log.Printf("[Cleanup] %s failed: %v", step.name, err)
|
for i := range steps {
|
||||||
|
step := steps[i]
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := step.fn(); err != nil {
|
||||||
|
log.Printf("[Cleanup] %s failed: %v", step.name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("[Cleanup] %s succeeded", step.name)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
runSequential := func(steps []cleanupStep) {
|
||||||
|
for i := range steps {
|
||||||
|
step := steps[i]
|
||||||
|
if err := step.fn(); err != nil {
|
||||||
|
log.Printf("[Cleanup] %s failed: %v", step.name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
log.Printf("[Cleanup] %s succeeded", step.name)
|
log.Printf("[Cleanup] %s succeeded", step.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
runParallel(parallelSteps)
|
||||||
|
runSequential(infraSteps)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
|
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
|
||||||
|
|||||||
81
backend/cmd/server/wire_gen_test.go
Normal file
81
backend/cmd/server/wire_gen_test.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProvideServiceBuildInfo(t *testing.T) {
|
||||||
|
in := handler.BuildInfo{
|
||||||
|
Version: "v-test",
|
||||||
|
BuildType: "release",
|
||||||
|
}
|
||||||
|
out := provideServiceBuildInfo(in)
|
||||||
|
require.Equal(t, in.Version, out.Version)
|
||||||
|
require.Equal(t, in.BuildType, out.BuildType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||||
|
cfg := &config.Config{}
|
||||||
|
|
||||||
|
oauthSvc := service.NewOAuthService(nil, nil)
|
||||||
|
openAIOAuthSvc := service.NewOpenAIOAuthService(nil, nil)
|
||||||
|
geminiOAuthSvc := service.NewGeminiOAuthService(nil, nil, nil, nil, cfg)
|
||||||
|
antigravityOAuthSvc := service.NewAntigravityOAuthService(nil)
|
||||||
|
|
||||||
|
tokenRefreshSvc := service.NewTokenRefreshService(
|
||||||
|
nil,
|
||||||
|
oauthSvc,
|
||||||
|
openAIOAuthSvc,
|
||||||
|
geminiOAuthSvc,
|
||||||
|
antigravityOAuthSvc,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
cfg,
|
||||||
|
)
|
||||||
|
accountExpirySvc := service.NewAccountExpiryService(nil, time.Second)
|
||||||
|
subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second)
|
||||||
|
pricingSvc := service.NewPricingService(cfg, nil)
|
||||||
|
emailQueueSvc := service.NewEmailQueueService(nil, 1)
|
||||||
|
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg)
|
||||||
|
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
|
||||||
|
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
|
||||||
|
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)
|
||||||
|
|
||||||
|
cleanup := provideCleanup(
|
||||||
|
nil, // entClient
|
||||||
|
nil, // redis
|
||||||
|
&service.OpsMetricsCollector{},
|
||||||
|
&service.OpsAggregationService{},
|
||||||
|
&service.OpsAlertEvaluatorService{},
|
||||||
|
&service.OpsCleanupService{},
|
||||||
|
&service.OpsScheduledReportService{},
|
||||||
|
opsSystemLogSinkSvc,
|
||||||
|
&service.SoraMediaCleanupService{},
|
||||||
|
schedulerSnapshotSvc,
|
||||||
|
tokenRefreshSvc,
|
||||||
|
accountExpirySvc,
|
||||||
|
subscriptionExpirySvc,
|
||||||
|
&service.UsageCleanupService{},
|
||||||
|
idempotencyCleanupSvc,
|
||||||
|
pricingSvc,
|
||||||
|
emailQueueSvc,
|
||||||
|
billingCacheSvc,
|
||||||
|
&service.UsageRecordWorkerPool{},
|
||||||
|
&service.SubscriptionService{},
|
||||||
|
oauthSvc,
|
||||||
|
openAIOAuthSvc,
|
||||||
|
geminiOAuthSvc,
|
||||||
|
antigravityOAuthSvc,
|
||||||
|
nil, // openAIGateway
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
cleanup()
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -63,6 +63,10 @@ type Account struct {
|
|||||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at,omitempty"`
|
RateLimitResetAt *time.Time `json:"rate_limit_reset_at,omitempty"`
|
||||||
// OverloadUntil holds the value of the "overload_until" field.
|
// OverloadUntil holds the value of the "overload_until" field.
|
||||||
OverloadUntil *time.Time `json:"overload_until,omitempty"`
|
OverloadUntil *time.Time `json:"overload_until,omitempty"`
|
||||||
|
// TempUnschedulableUntil holds the value of the "temp_unschedulable_until" field.
|
||||||
|
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
|
||||||
|
// TempUnschedulableReason holds the value of the "temp_unschedulable_reason" field.
|
||||||
|
TempUnschedulableReason *string `json:"temp_unschedulable_reason,omitempty"`
|
||||||
// SessionWindowStart holds the value of the "session_window_start" field.
|
// SessionWindowStart holds the value of the "session_window_start" field.
|
||||||
SessionWindowStart *time.Time `json:"session_window_start,omitempty"`
|
SessionWindowStart *time.Time `json:"session_window_start,omitempty"`
|
||||||
// SessionWindowEnd holds the value of the "session_window_end" field.
|
// SessionWindowEnd holds the value of the "session_window_end" field.
|
||||||
@@ -141,9 +145,9 @@ func (*Account) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
|
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
|
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd:
|
case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldTempUnschedulableUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd:
|
||||||
values[i] = new(sql.NullTime)
|
values[i] = new(sql.NullTime)
|
||||||
default:
|
default:
|
||||||
values[i] = new(sql.UnknownType)
|
values[i] = new(sql.UnknownType)
|
||||||
@@ -311,6 +315,20 @@ func (_m *Account) assignValues(columns []string, values []any) error {
|
|||||||
_m.OverloadUntil = new(time.Time)
|
_m.OverloadUntil = new(time.Time)
|
||||||
*_m.OverloadUntil = value.Time
|
*_m.OverloadUntil = value.Time
|
||||||
}
|
}
|
||||||
|
case account.FieldTempUnschedulableUntil:
|
||||||
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field temp_unschedulable_until", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.TempUnschedulableUntil = new(time.Time)
|
||||||
|
*_m.TempUnschedulableUntil = value.Time
|
||||||
|
}
|
||||||
|
case account.FieldTempUnschedulableReason:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field temp_unschedulable_reason", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.TempUnschedulableReason = new(string)
|
||||||
|
*_m.TempUnschedulableReason = value.String
|
||||||
|
}
|
||||||
case account.FieldSessionWindowStart:
|
case account.FieldSessionWindowStart:
|
||||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field session_window_start", values[i])
|
return fmt.Errorf("unexpected type %T for field session_window_start", values[i])
|
||||||
@@ -472,6 +490,16 @@ func (_m *Account) String() string {
|
|||||||
builder.WriteString(v.Format(time.ANSIC))
|
builder.WriteString(v.Format(time.ANSIC))
|
||||||
}
|
}
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
if v := _m.TempUnschedulableUntil; v != nil {
|
||||||
|
builder.WriteString("temp_unschedulable_until=")
|
||||||
|
builder.WriteString(v.Format(time.ANSIC))
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
|
if v := _m.TempUnschedulableReason; v != nil {
|
||||||
|
builder.WriteString("temp_unschedulable_reason=")
|
||||||
|
builder.WriteString(*v)
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
if v := _m.SessionWindowStart; v != nil {
|
if v := _m.SessionWindowStart; v != nil {
|
||||||
builder.WriteString("session_window_start=")
|
builder.WriteString("session_window_start=")
|
||||||
builder.WriteString(v.Format(time.ANSIC))
|
builder.WriteString(v.Format(time.ANSIC))
|
||||||
|
|||||||
@@ -59,6 +59,10 @@ const (
|
|||||||
FieldRateLimitResetAt = "rate_limit_reset_at"
|
FieldRateLimitResetAt = "rate_limit_reset_at"
|
||||||
// FieldOverloadUntil holds the string denoting the overload_until field in the database.
|
// FieldOverloadUntil holds the string denoting the overload_until field in the database.
|
||||||
FieldOverloadUntil = "overload_until"
|
FieldOverloadUntil = "overload_until"
|
||||||
|
// FieldTempUnschedulableUntil holds the string denoting the temp_unschedulable_until field in the database.
|
||||||
|
FieldTempUnschedulableUntil = "temp_unschedulable_until"
|
||||||
|
// FieldTempUnschedulableReason holds the string denoting the temp_unschedulable_reason field in the database.
|
||||||
|
FieldTempUnschedulableReason = "temp_unschedulable_reason"
|
||||||
// FieldSessionWindowStart holds the string denoting the session_window_start field in the database.
|
// FieldSessionWindowStart holds the string denoting the session_window_start field in the database.
|
||||||
FieldSessionWindowStart = "session_window_start"
|
FieldSessionWindowStart = "session_window_start"
|
||||||
// FieldSessionWindowEnd holds the string denoting the session_window_end field in the database.
|
// FieldSessionWindowEnd holds the string denoting the session_window_end field in the database.
|
||||||
@@ -128,6 +132,8 @@ var Columns = []string{
|
|||||||
FieldRateLimitedAt,
|
FieldRateLimitedAt,
|
||||||
FieldRateLimitResetAt,
|
FieldRateLimitResetAt,
|
||||||
FieldOverloadUntil,
|
FieldOverloadUntil,
|
||||||
|
FieldTempUnschedulableUntil,
|
||||||
|
FieldTempUnschedulableReason,
|
||||||
FieldSessionWindowStart,
|
FieldSessionWindowStart,
|
||||||
FieldSessionWindowEnd,
|
FieldSessionWindowEnd,
|
||||||
FieldSessionWindowStatus,
|
FieldSessionWindowStatus,
|
||||||
@@ -299,6 +305,16 @@ func ByOverloadUntil(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldOverloadUntil, opts...).ToFunc()
|
return sql.OrderByField(FieldOverloadUntil, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByTempUnschedulableUntil orders the results by the temp_unschedulable_until field.
|
||||||
|
func ByTempUnschedulableUntil(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldTempUnschedulableUntil, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByTempUnschedulableReason orders the results by the temp_unschedulable_reason field.
|
||||||
|
func ByTempUnschedulableReason(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldTempUnschedulableReason, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// BySessionWindowStart orders the results by the session_window_start field.
|
// BySessionWindowStart orders the results by the session_window_start field.
|
||||||
func BySessionWindowStart(opts ...sql.OrderTermOption) OrderOption {
|
func BySessionWindowStart(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldSessionWindowStart, opts...).ToFunc()
|
return sql.OrderByField(FieldSessionWindowStart, opts...).ToFunc()
|
||||||
|
|||||||
@@ -155,6 +155,16 @@ func OverloadUntil(v time.Time) predicate.Account {
|
|||||||
return predicate.Account(sql.FieldEQ(FieldOverloadUntil, v))
|
return predicate.Account(sql.FieldEQ(FieldOverloadUntil, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableUntil applies equality check predicate on the "temp_unschedulable_until" field. It's identical to TempUnschedulableUntilEQ.
|
||||||
|
func TempUnschedulableUntil(v time.Time) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldEQ(FieldTempUnschedulableUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReason applies equality check predicate on the "temp_unschedulable_reason" field. It's identical to TempUnschedulableReasonEQ.
|
||||||
|
func TempUnschedulableReason(v string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldEQ(FieldTempUnschedulableReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
// SessionWindowStart applies equality check predicate on the "session_window_start" field. It's identical to SessionWindowStartEQ.
|
// SessionWindowStart applies equality check predicate on the "session_window_start" field. It's identical to SessionWindowStartEQ.
|
||||||
func SessionWindowStart(v time.Time) predicate.Account {
|
func SessionWindowStart(v time.Time) predicate.Account {
|
||||||
return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v))
|
return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v))
|
||||||
@@ -1130,6 +1140,131 @@ func OverloadUntilNotNil() predicate.Account {
|
|||||||
return predicate.Account(sql.FieldNotNull(FieldOverloadUntil))
|
return predicate.Account(sql.FieldNotNull(FieldOverloadUntil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableUntilEQ applies the EQ predicate on the "temp_unschedulable_until" field.
|
||||||
|
func TempUnschedulableUntilEQ(v time.Time) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldEQ(FieldTempUnschedulableUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableUntilNEQ applies the NEQ predicate on the "temp_unschedulable_until" field.
|
||||||
|
func TempUnschedulableUntilNEQ(v time.Time) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldNEQ(FieldTempUnschedulableUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableUntilIn applies the In predicate on the "temp_unschedulable_until" field.
|
||||||
|
func TempUnschedulableUntilIn(vs ...time.Time) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldIn(FieldTempUnschedulableUntil, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableUntilNotIn applies the NotIn predicate on the "temp_unschedulable_until" field.
|
||||||
|
func TempUnschedulableUntilNotIn(vs ...time.Time) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldNotIn(FieldTempUnschedulableUntil, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableUntilGT applies the GT predicate on the "temp_unschedulable_until" field.
|
||||||
|
func TempUnschedulableUntilGT(v time.Time) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldGT(FieldTempUnschedulableUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableUntilGTE applies the GTE predicate on the "temp_unschedulable_until" field.
|
||||||
|
func TempUnschedulableUntilGTE(v time.Time) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldGTE(FieldTempUnschedulableUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableUntilLT applies the LT predicate on the "temp_unschedulable_until" field.
|
||||||
|
func TempUnschedulableUntilLT(v time.Time) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldLT(FieldTempUnschedulableUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableUntilLTE applies the LTE predicate on the "temp_unschedulable_until" field.
|
||||||
|
func TempUnschedulableUntilLTE(v time.Time) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldLTE(FieldTempUnschedulableUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableUntilIsNil applies the IsNil predicate on the "temp_unschedulable_until" field.
|
||||||
|
func TempUnschedulableUntilIsNil() predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldIsNull(FieldTempUnschedulableUntil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableUntilNotNil applies the NotNil predicate on the "temp_unschedulable_until" field.
|
||||||
|
func TempUnschedulableUntilNotNil() predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldNotNull(FieldTempUnschedulableUntil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonEQ applies the EQ predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonEQ(v string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldEQ(FieldTempUnschedulableReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonNEQ applies the NEQ predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonNEQ(v string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldNEQ(FieldTempUnschedulableReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonIn applies the In predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonIn(vs ...string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldIn(FieldTempUnschedulableReason, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonNotIn applies the NotIn predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonNotIn(vs ...string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldNotIn(FieldTempUnschedulableReason, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonGT applies the GT predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonGT(v string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldGT(FieldTempUnschedulableReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonGTE applies the GTE predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonGTE(v string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldGTE(FieldTempUnschedulableReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonLT applies the LT predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonLT(v string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldLT(FieldTempUnschedulableReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonLTE applies the LTE predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonLTE(v string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldLTE(FieldTempUnschedulableReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonContains applies the Contains predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonContains(v string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldContains(FieldTempUnschedulableReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonHasPrefix applies the HasPrefix predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonHasPrefix(v string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldHasPrefix(FieldTempUnschedulableReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonHasSuffix applies the HasSuffix predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonHasSuffix(v string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldHasSuffix(FieldTempUnschedulableReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonIsNil applies the IsNil predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonIsNil() predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldIsNull(FieldTempUnschedulableReason))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonNotNil applies the NotNil predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonNotNil() predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldNotNull(FieldTempUnschedulableReason))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonEqualFold applies the EqualFold predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonEqualFold(v string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldEqualFold(FieldTempUnschedulableReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUnschedulableReasonContainsFold applies the ContainsFold predicate on the "temp_unschedulable_reason" field.
|
||||||
|
func TempUnschedulableReasonContainsFold(v string) predicate.Account {
|
||||||
|
return predicate.Account(sql.FieldContainsFold(FieldTempUnschedulableReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
// SessionWindowStartEQ applies the EQ predicate on the "session_window_start" field.
|
// SessionWindowStartEQ applies the EQ predicate on the "session_window_start" field.
|
||||||
func SessionWindowStartEQ(v time.Time) predicate.Account {
|
func SessionWindowStartEQ(v time.Time) predicate.Account {
|
||||||
return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v))
|
return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v))
|
||||||
|
|||||||
@@ -293,6 +293,34 @@ func (_c *AccountCreate) SetNillableOverloadUntil(v *time.Time) *AccountCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field.
|
||||||
|
func (_c *AccountCreate) SetTempUnschedulableUntil(v time.Time) *AccountCreate {
|
||||||
|
_c.mutation.SetTempUnschedulableUntil(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil.
|
||||||
|
func (_c *AccountCreate) SetNillableTempUnschedulableUntil(v *time.Time) *AccountCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetTempUnschedulableUntil(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field.
|
||||||
|
func (_c *AccountCreate) SetTempUnschedulableReason(v string) *AccountCreate {
|
||||||
|
_c.mutation.SetTempUnschedulableReason(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil.
|
||||||
|
func (_c *AccountCreate) SetNillableTempUnschedulableReason(v *string) *AccountCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetTempUnschedulableReason(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetSessionWindowStart sets the "session_window_start" field.
|
// SetSessionWindowStart sets the "session_window_start" field.
|
||||||
func (_c *AccountCreate) SetSessionWindowStart(v time.Time) *AccountCreate {
|
func (_c *AccountCreate) SetSessionWindowStart(v time.Time) *AccountCreate {
|
||||||
_c.mutation.SetSessionWindowStart(v)
|
_c.mutation.SetSessionWindowStart(v)
|
||||||
@@ -639,6 +667,14 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(account.FieldOverloadUntil, field.TypeTime, value)
|
_spec.SetField(account.FieldOverloadUntil, field.TypeTime, value)
|
||||||
_node.OverloadUntil = &value
|
_node.OverloadUntil = &value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.TempUnschedulableUntil(); ok {
|
||||||
|
_spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value)
|
||||||
|
_node.TempUnschedulableUntil = &value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.TempUnschedulableReason(); ok {
|
||||||
|
_spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value)
|
||||||
|
_node.TempUnschedulableReason = &value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.SessionWindowStart(); ok {
|
if value, ok := _c.mutation.SessionWindowStart(); ok {
|
||||||
_spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value)
|
_spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value)
|
||||||
_node.SessionWindowStart = &value
|
_node.SessionWindowStart = &value
|
||||||
@@ -1080,6 +1116,42 @@ func (u *AccountUpsert) ClearOverloadUntil() *AccountUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field.
|
||||||
|
func (u *AccountUpsert) SetTempUnschedulableUntil(v time.Time) *AccountUpsert {
|
||||||
|
u.Set(account.FieldTempUnschedulableUntil, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create.
|
||||||
|
func (u *AccountUpsert) UpdateTempUnschedulableUntil() *AccountUpsert {
|
||||||
|
u.SetExcluded(account.FieldTempUnschedulableUntil)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field.
|
||||||
|
func (u *AccountUpsert) ClearTempUnschedulableUntil() *AccountUpsert {
|
||||||
|
u.SetNull(account.FieldTempUnschedulableUntil)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field.
|
||||||
|
func (u *AccountUpsert) SetTempUnschedulableReason(v string) *AccountUpsert {
|
||||||
|
u.Set(account.FieldTempUnschedulableReason, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create.
|
||||||
|
func (u *AccountUpsert) UpdateTempUnschedulableReason() *AccountUpsert {
|
||||||
|
u.SetExcluded(account.FieldTempUnschedulableReason)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field.
|
||||||
|
func (u *AccountUpsert) ClearTempUnschedulableReason() *AccountUpsert {
|
||||||
|
u.SetNull(account.FieldTempUnschedulableReason)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// SetSessionWindowStart sets the "session_window_start" field.
|
// SetSessionWindowStart sets the "session_window_start" field.
|
||||||
func (u *AccountUpsert) SetSessionWindowStart(v time.Time) *AccountUpsert {
|
func (u *AccountUpsert) SetSessionWindowStart(v time.Time) *AccountUpsert {
|
||||||
u.Set(account.FieldSessionWindowStart, v)
|
u.Set(account.FieldSessionWindowStart, v)
|
||||||
@@ -1557,6 +1629,48 @@ func (u *AccountUpsertOne) ClearOverloadUntil() *AccountUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field.
|
||||||
|
func (u *AccountUpsertOne) SetTempUnschedulableUntil(v time.Time) *AccountUpsertOne {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.SetTempUnschedulableUntil(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create.
|
||||||
|
func (u *AccountUpsertOne) UpdateTempUnschedulableUntil() *AccountUpsertOne {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.UpdateTempUnschedulableUntil()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field.
|
||||||
|
func (u *AccountUpsertOne) ClearTempUnschedulableUntil() *AccountUpsertOne {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.ClearTempUnschedulableUntil()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field.
|
||||||
|
func (u *AccountUpsertOne) SetTempUnschedulableReason(v string) *AccountUpsertOne {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.SetTempUnschedulableReason(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create.
|
||||||
|
func (u *AccountUpsertOne) UpdateTempUnschedulableReason() *AccountUpsertOne {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.UpdateTempUnschedulableReason()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field.
|
||||||
|
func (u *AccountUpsertOne) ClearTempUnschedulableReason() *AccountUpsertOne {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.ClearTempUnschedulableReason()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetSessionWindowStart sets the "session_window_start" field.
|
// SetSessionWindowStart sets the "session_window_start" field.
|
||||||
func (u *AccountUpsertOne) SetSessionWindowStart(v time.Time) *AccountUpsertOne {
|
func (u *AccountUpsertOne) SetSessionWindowStart(v time.Time) *AccountUpsertOne {
|
||||||
return u.Update(func(s *AccountUpsert) {
|
return u.Update(func(s *AccountUpsert) {
|
||||||
@@ -2209,6 +2323,48 @@ func (u *AccountUpsertBulk) ClearOverloadUntil() *AccountUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field.
|
||||||
|
func (u *AccountUpsertBulk) SetTempUnschedulableUntil(v time.Time) *AccountUpsertBulk {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.SetTempUnschedulableUntil(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create.
|
||||||
|
func (u *AccountUpsertBulk) UpdateTempUnschedulableUntil() *AccountUpsertBulk {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.UpdateTempUnschedulableUntil()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field.
|
||||||
|
func (u *AccountUpsertBulk) ClearTempUnschedulableUntil() *AccountUpsertBulk {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.ClearTempUnschedulableUntil()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field.
|
||||||
|
func (u *AccountUpsertBulk) SetTempUnschedulableReason(v string) *AccountUpsertBulk {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.SetTempUnschedulableReason(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create.
|
||||||
|
func (u *AccountUpsertBulk) UpdateTempUnschedulableReason() *AccountUpsertBulk {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.UpdateTempUnschedulableReason()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field.
|
||||||
|
func (u *AccountUpsertBulk) ClearTempUnschedulableReason() *AccountUpsertBulk {
|
||||||
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
s.ClearTempUnschedulableReason()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetSessionWindowStart sets the "session_window_start" field.
|
// SetSessionWindowStart sets the "session_window_start" field.
|
||||||
func (u *AccountUpsertBulk) SetSessionWindowStart(v time.Time) *AccountUpsertBulk {
|
func (u *AccountUpsertBulk) SetSessionWindowStart(v time.Time) *AccountUpsertBulk {
|
||||||
return u.Update(func(s *AccountUpsert) {
|
return u.Update(func(s *AccountUpsert) {
|
||||||
|
|||||||
@@ -376,6 +376,46 @@ func (_u *AccountUpdate) ClearOverloadUntil() *AccountUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field.
|
||||||
|
func (_u *AccountUpdate) SetTempUnschedulableUntil(v time.Time) *AccountUpdate {
|
||||||
|
_u.mutation.SetTempUnschedulableUntil(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil.
|
||||||
|
func (_u *AccountUpdate) SetNillableTempUnschedulableUntil(v *time.Time) *AccountUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetTempUnschedulableUntil(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field.
|
||||||
|
func (_u *AccountUpdate) ClearTempUnschedulableUntil() *AccountUpdate {
|
||||||
|
_u.mutation.ClearTempUnschedulableUntil()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field.
|
||||||
|
func (_u *AccountUpdate) SetTempUnschedulableReason(v string) *AccountUpdate {
|
||||||
|
_u.mutation.SetTempUnschedulableReason(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil.
|
||||||
|
func (_u *AccountUpdate) SetNillableTempUnschedulableReason(v *string) *AccountUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetTempUnschedulableReason(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field.
|
||||||
|
func (_u *AccountUpdate) ClearTempUnschedulableReason() *AccountUpdate {
|
||||||
|
_u.mutation.ClearTempUnschedulableReason()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetSessionWindowStart sets the "session_window_start" field.
|
// SetSessionWindowStart sets the "session_window_start" field.
|
||||||
func (_u *AccountUpdate) SetSessionWindowStart(v time.Time) *AccountUpdate {
|
func (_u *AccountUpdate) SetSessionWindowStart(v time.Time) *AccountUpdate {
|
||||||
_u.mutation.SetSessionWindowStart(v)
|
_u.mutation.SetSessionWindowStart(v)
|
||||||
@@ -701,6 +741,18 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.OverloadUntilCleared() {
|
if _u.mutation.OverloadUntilCleared() {
|
||||||
_spec.ClearField(account.FieldOverloadUntil, field.TypeTime)
|
_spec.ClearField(account.FieldOverloadUntil, field.TypeTime)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.TempUnschedulableUntil(); ok {
|
||||||
|
_spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.TempUnschedulableUntilCleared() {
|
||||||
|
_spec.ClearField(account.FieldTempUnschedulableUntil, field.TypeTime)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.TempUnschedulableReason(); ok {
|
||||||
|
_spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.TempUnschedulableReasonCleared() {
|
||||||
|
_spec.ClearField(account.FieldTempUnschedulableReason, field.TypeString)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.SessionWindowStart(); ok {
|
if value, ok := _u.mutation.SessionWindowStart(); ok {
|
||||||
_spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value)
|
_spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value)
|
||||||
}
|
}
|
||||||
@@ -1215,6 +1267,46 @@ func (_u *AccountUpdateOne) ClearOverloadUntil() *AccountUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field.
|
||||||
|
func (_u *AccountUpdateOne) SetTempUnschedulableUntil(v time.Time) *AccountUpdateOne {
|
||||||
|
_u.mutation.SetTempUnschedulableUntil(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil.
|
||||||
|
func (_u *AccountUpdateOne) SetNillableTempUnschedulableUntil(v *time.Time) *AccountUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetTempUnschedulableUntil(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field.
|
||||||
|
func (_u *AccountUpdateOne) ClearTempUnschedulableUntil() *AccountUpdateOne {
|
||||||
|
_u.mutation.ClearTempUnschedulableUntil()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field.
|
||||||
|
func (_u *AccountUpdateOne) SetTempUnschedulableReason(v string) *AccountUpdateOne {
|
||||||
|
_u.mutation.SetTempUnschedulableReason(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil.
|
||||||
|
func (_u *AccountUpdateOne) SetNillableTempUnschedulableReason(v *string) *AccountUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetTempUnschedulableReason(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field.
|
||||||
|
func (_u *AccountUpdateOne) ClearTempUnschedulableReason() *AccountUpdateOne {
|
||||||
|
_u.mutation.ClearTempUnschedulableReason()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetSessionWindowStart sets the "session_window_start" field.
|
// SetSessionWindowStart sets the "session_window_start" field.
|
||||||
func (_u *AccountUpdateOne) SetSessionWindowStart(v time.Time) *AccountUpdateOne {
|
func (_u *AccountUpdateOne) SetSessionWindowStart(v time.Time) *AccountUpdateOne {
|
||||||
_u.mutation.SetSessionWindowStart(v)
|
_u.mutation.SetSessionWindowStart(v)
|
||||||
@@ -1570,6 +1662,18 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
|
|||||||
if _u.mutation.OverloadUntilCleared() {
|
if _u.mutation.OverloadUntilCleared() {
|
||||||
_spec.ClearField(account.FieldOverloadUntil, field.TypeTime)
|
_spec.ClearField(account.FieldOverloadUntil, field.TypeTime)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.TempUnschedulableUntil(); ok {
|
||||||
|
_spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.TempUnschedulableUntilCleared() {
|
||||||
|
_spec.ClearField(account.FieldTempUnschedulableUntil, field.TypeTime)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.TempUnschedulableReason(); ok {
|
||||||
|
_spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.TempUnschedulableReasonCleared() {
|
||||||
|
_spec.ClearField(account.FieldTempUnschedulableReason, field.TypeString)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.SessionWindowStart(); ok {
|
if value, ok := _u.mutation.SessionWindowStart(); ok {
|
||||||
_spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value)
|
_spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||||
@@ -58,6 +59,8 @@ type Client struct {
|
|||||||
ErrorPassthroughRule *ErrorPassthroughRuleClient
|
ErrorPassthroughRule *ErrorPassthroughRuleClient
|
||||||
// Group is the client for interacting with the Group builders.
|
// Group is the client for interacting with the Group builders.
|
||||||
Group *GroupClient
|
Group *GroupClient
|
||||||
|
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
|
||||||
|
IdempotencyRecord *IdempotencyRecordClient
|
||||||
// PromoCode is the client for interacting with the PromoCode builders.
|
// PromoCode is the client for interacting with the PromoCode builders.
|
||||||
PromoCode *PromoCodeClient
|
PromoCode *PromoCodeClient
|
||||||
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
|
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
|
||||||
@@ -102,6 +105,7 @@ func (c *Client) init() {
|
|||||||
c.AnnouncementRead = NewAnnouncementReadClient(c.config)
|
c.AnnouncementRead = NewAnnouncementReadClient(c.config)
|
||||||
c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config)
|
c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config)
|
||||||
c.Group = NewGroupClient(c.config)
|
c.Group = NewGroupClient(c.config)
|
||||||
|
c.IdempotencyRecord = NewIdempotencyRecordClient(c.config)
|
||||||
c.PromoCode = NewPromoCodeClient(c.config)
|
c.PromoCode = NewPromoCodeClient(c.config)
|
||||||
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
|
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
|
||||||
c.Proxy = NewProxyClient(c.config)
|
c.Proxy = NewProxyClient(c.config)
|
||||||
@@ -214,6 +218,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
|
|||||||
AnnouncementRead: NewAnnouncementReadClient(cfg),
|
AnnouncementRead: NewAnnouncementReadClient(cfg),
|
||||||
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
|
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
|
||||||
Group: NewGroupClient(cfg),
|
Group: NewGroupClient(cfg),
|
||||||
|
IdempotencyRecord: NewIdempotencyRecordClient(cfg),
|
||||||
PromoCode: NewPromoCodeClient(cfg),
|
PromoCode: NewPromoCodeClient(cfg),
|
||||||
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
|
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
|
||||||
Proxy: NewProxyClient(cfg),
|
Proxy: NewProxyClient(cfg),
|
||||||
@@ -253,6 +258,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
|
|||||||
AnnouncementRead: NewAnnouncementReadClient(cfg),
|
AnnouncementRead: NewAnnouncementReadClient(cfg),
|
||||||
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
|
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
|
||||||
Group: NewGroupClient(cfg),
|
Group: NewGroupClient(cfg),
|
||||||
|
IdempotencyRecord: NewIdempotencyRecordClient(cfg),
|
||||||
PromoCode: NewPromoCodeClient(cfg),
|
PromoCode: NewPromoCodeClient(cfg),
|
||||||
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
|
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
|
||||||
Proxy: NewProxyClient(cfg),
|
Proxy: NewProxyClient(cfg),
|
||||||
@@ -296,10 +302,10 @@ func (c *Client) Close() error {
|
|||||||
func (c *Client) Use(hooks ...Hook) {
|
func (c *Client) Use(hooks ...Hook) {
|
||||||
for _, n := range []interface{ Use(...Hook) }{
|
for _, n := range []interface{ Use(...Hook) }{
|
||||||
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
|
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
|
||||||
c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy,
|
c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode,
|
||||||
c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog,
|
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
|
||||||
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
|
c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
|
||||||
c.UserSubscription,
|
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
|
||||||
} {
|
} {
|
||||||
n.Use(hooks...)
|
n.Use(hooks...)
|
||||||
}
|
}
|
||||||
@@ -310,10 +316,10 @@ func (c *Client) Use(hooks ...Hook) {
|
|||||||
func (c *Client) Intercept(interceptors ...Interceptor) {
|
func (c *Client) Intercept(interceptors ...Interceptor) {
|
||||||
for _, n := range []interface{ Intercept(...Interceptor) }{
|
for _, n := range []interface{ Intercept(...Interceptor) }{
|
||||||
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
|
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
|
||||||
c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy,
|
c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode,
|
||||||
c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog,
|
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
|
||||||
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
|
c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
|
||||||
c.UserSubscription,
|
c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
|
||||||
} {
|
} {
|
||||||
n.Intercept(interceptors...)
|
n.Intercept(interceptors...)
|
||||||
}
|
}
|
||||||
@@ -336,6 +342,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
|
|||||||
return c.ErrorPassthroughRule.mutate(ctx, m)
|
return c.ErrorPassthroughRule.mutate(ctx, m)
|
||||||
case *GroupMutation:
|
case *GroupMutation:
|
||||||
return c.Group.mutate(ctx, m)
|
return c.Group.mutate(ctx, m)
|
||||||
|
case *IdempotencyRecordMutation:
|
||||||
|
return c.IdempotencyRecord.mutate(ctx, m)
|
||||||
case *PromoCodeMutation:
|
case *PromoCodeMutation:
|
||||||
return c.PromoCode.mutate(ctx, m)
|
return c.PromoCode.mutate(ctx, m)
|
||||||
case *PromoCodeUsageMutation:
|
case *PromoCodeUsageMutation:
|
||||||
@@ -1575,6 +1583,139 @@ func (c *GroupClient) mutate(ctx context.Context, m *GroupMutation) (Value, erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IdempotencyRecordClient is a client for the IdempotencyRecord schema.
|
||||||
|
type IdempotencyRecordClient struct {
|
||||||
|
config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIdempotencyRecordClient returns a client for the IdempotencyRecord from the given config.
|
||||||
|
func NewIdempotencyRecordClient(c config) *IdempotencyRecordClient {
|
||||||
|
return &IdempotencyRecordClient{config: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use adds a list of mutation hooks to the hooks stack.
|
||||||
|
// A call to `Use(f, g, h)` equals to `idempotencyrecord.Hooks(f(g(h())))`.
|
||||||
|
func (c *IdempotencyRecordClient) Use(hooks ...Hook) {
|
||||||
|
c.hooks.IdempotencyRecord = append(c.hooks.IdempotencyRecord, hooks...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intercept adds a list of query interceptors to the interceptors stack.
|
||||||
|
// A call to `Intercept(f, g, h)` equals to `idempotencyrecord.Intercept(f(g(h())))`.
|
||||||
|
func (c *IdempotencyRecordClient) Intercept(interceptors ...Interceptor) {
|
||||||
|
c.inters.IdempotencyRecord = append(c.inters.IdempotencyRecord, interceptors...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create returns a builder for creating a IdempotencyRecord entity.
|
||||||
|
func (c *IdempotencyRecordClient) Create() *IdempotencyRecordCreate {
|
||||||
|
mutation := newIdempotencyRecordMutation(c.config, OpCreate)
|
||||||
|
return &IdempotencyRecordCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateBulk returns a builder for creating a bulk of IdempotencyRecord entities.
|
||||||
|
func (c *IdempotencyRecordClient) CreateBulk(builders ...*IdempotencyRecordCreate) *IdempotencyRecordCreateBulk {
|
||||||
|
return &IdempotencyRecordCreateBulk{config: c.config, builders: builders}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
|
||||||
|
// a builder and applies setFunc on it.
|
||||||
|
func (c *IdempotencyRecordClient) MapCreateBulk(slice any, setFunc func(*IdempotencyRecordCreate, int)) *IdempotencyRecordCreateBulk {
|
||||||
|
rv := reflect.ValueOf(slice)
|
||||||
|
if rv.Kind() != reflect.Slice {
|
||||||
|
return &IdempotencyRecordCreateBulk{err: fmt.Errorf("calling to IdempotencyRecordClient.MapCreateBulk with wrong type %T, need slice", slice)}
|
||||||
|
}
|
||||||
|
builders := make([]*IdempotencyRecordCreate, rv.Len())
|
||||||
|
for i := 0; i < rv.Len(); i++ {
|
||||||
|
builders[i] = c.Create()
|
||||||
|
setFunc(builders[i], i)
|
||||||
|
}
|
||||||
|
return &IdempotencyRecordCreateBulk{config: c.config, builders: builders}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update returns an update builder for IdempotencyRecord.
|
||||||
|
func (c *IdempotencyRecordClient) Update() *IdempotencyRecordUpdate {
|
||||||
|
mutation := newIdempotencyRecordMutation(c.config, OpUpdate)
|
||||||
|
return &IdempotencyRecordUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOne returns an update builder for the given entity.
|
||||||
|
func (c *IdempotencyRecordClient) UpdateOne(_m *IdempotencyRecord) *IdempotencyRecordUpdateOne {
|
||||||
|
mutation := newIdempotencyRecordMutation(c.config, OpUpdateOne, withIdempotencyRecord(_m))
|
||||||
|
return &IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOneID returns an update builder for the given id.
|
||||||
|
func (c *IdempotencyRecordClient) UpdateOneID(id int64) *IdempotencyRecordUpdateOne {
|
||||||
|
mutation := newIdempotencyRecordMutation(c.config, OpUpdateOne, withIdempotencyRecordID(id))
|
||||||
|
return &IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete returns a delete builder for IdempotencyRecord.
|
||||||
|
func (c *IdempotencyRecordClient) Delete() *IdempotencyRecordDelete {
|
||||||
|
mutation := newIdempotencyRecordMutation(c.config, OpDelete)
|
||||||
|
return &IdempotencyRecordDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteOne returns a builder for deleting the given entity.
|
||||||
|
func (c *IdempotencyRecordClient) DeleteOne(_m *IdempotencyRecord) *IdempotencyRecordDeleteOne {
|
||||||
|
return c.DeleteOneID(_m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteOneID returns a builder for deleting the given entity by its id.
|
||||||
|
func (c *IdempotencyRecordClient) DeleteOneID(id int64) *IdempotencyRecordDeleteOne {
|
||||||
|
builder := c.Delete().Where(idempotencyrecord.ID(id))
|
||||||
|
builder.mutation.id = &id
|
||||||
|
builder.mutation.op = OpDeleteOne
|
||||||
|
return &IdempotencyRecordDeleteOne{builder}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query returns a query builder for IdempotencyRecord.
|
||||||
|
func (c *IdempotencyRecordClient) Query() *IdempotencyRecordQuery {
|
||||||
|
return &IdempotencyRecordQuery{
|
||||||
|
config: c.config,
|
||||||
|
ctx: &QueryContext{Type: TypeIdempotencyRecord},
|
||||||
|
inters: c.Interceptors(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a IdempotencyRecord entity by its id.
|
||||||
|
func (c *IdempotencyRecordClient) Get(ctx context.Context, id int64) (*IdempotencyRecord, error) {
|
||||||
|
return c.Query().Where(idempotencyrecord.ID(id)).Only(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetX is like Get, but panics if an error occurs.
|
||||||
|
func (c *IdempotencyRecordClient) GetX(ctx context.Context, id int64) *IdempotencyRecord {
|
||||||
|
obj, err := c.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return obj
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hooks returns the client hooks.
|
||||||
|
func (c *IdempotencyRecordClient) Hooks() []Hook {
|
||||||
|
return c.hooks.IdempotencyRecord
|
||||||
|
}
|
||||||
|
|
||||||
|
// Interceptors returns the client interceptors.
|
||||||
|
func (c *IdempotencyRecordClient) Interceptors() []Interceptor {
|
||||||
|
return c.inters.IdempotencyRecord
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyRecordMutation) (Value, error) {
|
||||||
|
switch m.Op() {
|
||||||
|
case OpCreate:
|
||||||
|
return (&IdempotencyRecordCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||||
|
case OpUpdate:
|
||||||
|
return (&IdempotencyRecordUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||||
|
case OpUpdateOne:
|
||||||
|
return (&IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||||
|
case OpDelete, OpDeleteOne:
|
||||||
|
return (&IdempotencyRecordDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("ent: unknown IdempotencyRecord mutation op: %q", m.Op())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// PromoCodeClient is a client for the PromoCode schema.
|
// PromoCodeClient is a client for the PromoCode schema.
|
||||||
type PromoCodeClient struct {
|
type PromoCodeClient struct {
|
||||||
config
|
config
|
||||||
@@ -3747,15 +3888,17 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
|
|||||||
type (
|
type (
|
||||||
hooks struct {
|
hooks struct {
|
||||||
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
|
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
|
||||||
ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
|
ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage,
|
||||||
SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
|
Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User,
|
||||||
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
|
UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
|
||||||
|
UserSubscription []ent.Hook
|
||||||
}
|
}
|
||||||
inters struct {
|
inters struct {
|
||||||
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
|
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
|
||||||
ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
|
ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage,
|
||||||
SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
|
Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User,
|
||||||
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
|
UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
|
||||||
|
UserSubscription []ent.Interceptor
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||||
@@ -99,6 +100,7 @@ func checkColumn(t, c string) error {
|
|||||||
announcementread.Table: announcementread.ValidColumn,
|
announcementread.Table: announcementread.ValidColumn,
|
||||||
errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
|
errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
|
||||||
group.Table: group.ValidColumn,
|
group.Table: group.ValidColumn,
|
||||||
|
idempotencyrecord.Table: idempotencyrecord.ValidColumn,
|
||||||
promocode.Table: promocode.ValidColumn,
|
promocode.Table: promocode.ValidColumn,
|
||||||
promocodeusage.Table: promocodeusage.ValidColumn,
|
promocodeusage.Table: promocodeusage.ValidColumn,
|
||||||
proxy.Table: proxy.ValidColumn,
|
proxy.Table: proxy.ValidColumn,
|
||||||
|
|||||||
@@ -60,6 +60,8 @@ type Group struct {
|
|||||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
|
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
|
||||||
// SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field.
|
// SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field.
|
||||||
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
||||||
|
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
||||||
|
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
||||||
// 是否仅允许 Claude Code 客户端
|
// 是否仅允许 Claude Code 客户端
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||||
// 非 Claude Code 请求降级使用的分组 ID
|
// 非 Claude Code 请求降级使用的分组 ID
|
||||||
@@ -188,7 +190,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
|
case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
@@ -353,6 +355,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
|||||||
_m.SoraVideoPricePerRequestHd = new(float64)
|
_m.SoraVideoPricePerRequestHd = new(float64)
|
||||||
*_m.SoraVideoPricePerRequestHd = value.Float64
|
*_m.SoraVideoPricePerRequestHd = value.Float64
|
||||||
}
|
}
|
||||||
|
case group.FieldSoraStorageQuotaBytes:
|
||||||
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.SoraStorageQuotaBytes = value.Int64
|
||||||
|
}
|
||||||
case group.FieldClaudeCodeOnly:
|
case group.FieldClaudeCodeOnly:
|
||||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
|
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
|
||||||
@@ -570,6 +578,9 @@ func (_m *Group) String() string {
|
|||||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
}
|
}
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("sora_storage_quota_bytes=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
|
||||||
|
builder.WriteString(", ")
|
||||||
builder.WriteString("claude_code_only=")
|
builder.WriteString("claude_code_only=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
|
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ const (
|
|||||||
FieldSoraVideoPricePerRequest = "sora_video_price_per_request"
|
FieldSoraVideoPricePerRequest = "sora_video_price_per_request"
|
||||||
// FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database.
|
// FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database.
|
||||||
FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd"
|
FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd"
|
||||||
|
// FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
|
||||||
|
FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
|
||||||
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
|
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
|
||||||
FieldClaudeCodeOnly = "claude_code_only"
|
FieldClaudeCodeOnly = "claude_code_only"
|
||||||
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
||||||
@@ -169,6 +171,7 @@ var Columns = []string{
|
|||||||
FieldSoraImagePrice540,
|
FieldSoraImagePrice540,
|
||||||
FieldSoraVideoPricePerRequest,
|
FieldSoraVideoPricePerRequest,
|
||||||
FieldSoraVideoPricePerRequestHd,
|
FieldSoraVideoPricePerRequestHd,
|
||||||
|
FieldSoraStorageQuotaBytes,
|
||||||
FieldClaudeCodeOnly,
|
FieldClaudeCodeOnly,
|
||||||
FieldFallbackGroupID,
|
FieldFallbackGroupID,
|
||||||
FieldFallbackGroupIDOnInvalidRequest,
|
FieldFallbackGroupIDOnInvalidRequest,
|
||||||
@@ -232,6 +235,8 @@ var (
|
|||||||
SubscriptionTypeValidator func(string) error
|
SubscriptionTypeValidator func(string) error
|
||||||
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
|
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
|
||||||
DefaultDefaultValidityDays int
|
DefaultDefaultValidityDays int
|
||||||
|
// DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
|
||||||
|
DefaultSoraStorageQuotaBytes int64
|
||||||
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
|
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
|
||||||
DefaultClaudeCodeOnly bool
|
DefaultClaudeCodeOnly bool
|
||||||
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
|
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
|
||||||
@@ -357,6 +362,11 @@ func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc()
|
return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
|
||||||
|
func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByClaudeCodeOnly orders the results by the claude_code_only field.
|
// ByClaudeCodeOnly orders the results by the claude_code_only field.
|
||||||
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
|
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
|
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
|
||||||
|
|||||||
@@ -160,6 +160,11 @@ func SoraVideoPricePerRequestHd(v float64) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
|
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
|
||||||
|
func SoraStorageQuotaBytes(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
|
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
|
||||||
func ClaudeCodeOnly(v bool) predicate.Group {
|
func ClaudeCodeOnly(v bool) predicate.Group {
|
||||||
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
||||||
@@ -1245,6 +1250,46 @@ func SoraVideoPricePerRequestHdNotNil() predicate.Group {
|
|||||||
return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd))
|
return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesEQ(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesNEQ(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesIn(vs ...int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesGT(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesGTE(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesLT(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesLTE(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
|
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
|
||||||
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
|
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
|
||||||
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
||||||
|
|||||||
@@ -314,6 +314,20 @@ func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupC
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||||
|
func (_c *GroupCreate) SetSoraStorageQuotaBytes(v int64) *GroupCreate {
|
||||||
|
_c.mutation.SetSoraStorageQuotaBytes(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
||||||
|
func (_c *GroupCreate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetSoraStorageQuotaBytes(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
|
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
|
||||||
_c.mutation.SetClaudeCodeOnly(v)
|
_c.mutation.SetClaudeCodeOnly(v)
|
||||||
@@ -575,6 +589,10 @@ func (_c *GroupCreate) defaults() error {
|
|||||||
v := group.DefaultDefaultValidityDays
|
v := group.DefaultDefaultValidityDays
|
||||||
_c.mutation.SetDefaultValidityDays(v)
|
_c.mutation.SetDefaultValidityDays(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
|
||||||
|
v := group.DefaultSoraStorageQuotaBytes
|
||||||
|
_c.mutation.SetSoraStorageQuotaBytes(v)
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||||
v := group.DefaultClaudeCodeOnly
|
v := group.DefaultClaudeCodeOnly
|
||||||
_c.mutation.SetClaudeCodeOnly(v)
|
_c.mutation.SetClaudeCodeOnly(v)
|
||||||
@@ -647,6 +665,9 @@ func (_c *GroupCreate) check() error {
|
|||||||
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
|
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
|
||||||
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
|
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
|
||||||
|
return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "Group.sora_storage_quota_bytes"`)}
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||||
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
|
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
|
||||||
}
|
}
|
||||||
@@ -773,6 +794,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
|
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
|
||||||
_node.SoraVideoPricePerRequestHd = &value
|
_node.SoraVideoPricePerRequestHd = &value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
|
||||||
|
_spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||||
|
_node.SoraStorageQuotaBytes = value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
|
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
|
||||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||||
_node.ClaudeCodeOnly = value
|
_node.ClaudeCodeOnly = value
|
||||||
@@ -1345,6 +1370,24 @@ func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||||
|
func (u *GroupUpsert) SetSoraStorageQuotaBytes(v int64) *GroupUpsert {
|
||||||
|
u.Set(group.FieldSoraStorageQuotaBytes, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsert) UpdateSoraStorageQuotaBytes() *GroupUpsert {
|
||||||
|
u.SetExcluded(group.FieldSoraStorageQuotaBytes)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
||||||
|
func (u *GroupUpsert) AddSoraStorageQuotaBytes(v int64) *GroupUpsert {
|
||||||
|
u.Add(group.FieldSoraStorageQuotaBytes, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
|
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
|
||||||
u.Set(group.FieldClaudeCodeOnly, v)
|
u.Set(group.FieldClaudeCodeOnly, v)
|
||||||
@@ -1970,6 +2013,27 @@ func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||||
|
func (u *GroupUpsertOne) SetSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetSoraStorageQuotaBytes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
||||||
|
func (u *GroupUpsertOne) AddSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.AddSoraStorageQuotaBytes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertOne) UpdateSoraStorageQuotaBytes() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateSoraStorageQuotaBytes()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
|
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
|
||||||
return u.Update(func(s *GroupUpsert) {
|
return u.Update(func(s *GroupUpsert) {
|
||||||
@@ -2783,6 +2847,27 @@ func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||||
|
func (u *GroupUpsertBulk) SetSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetSoraStorageQuotaBytes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
||||||
|
func (u *GroupUpsertBulk) AddSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.AddSoraStorageQuotaBytes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertBulk) UpdateSoraStorageQuotaBytes() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateSoraStorageQuotaBytes()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
|
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
|
||||||
return u.Update(func(s *GroupUpsert) {
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
|||||||
@@ -463,6 +463,27 @@ func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||||
|
func (_u *GroupUpdate) SetSoraStorageQuotaBytes(v int64) *GroupUpdate {
|
||||||
|
_u.mutation.ResetSoraStorageQuotaBytes()
|
||||||
|
_u.mutation.SetSoraStorageQuotaBytes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetSoraStorageQuotaBytes(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
|
||||||
|
func (_u *GroupUpdate) AddSoraStorageQuotaBytes(v int64) *GroupUpdate {
|
||||||
|
_u.mutation.AddSoraStorageQuotaBytes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
|
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
|
||||||
_u.mutation.SetClaudeCodeOnly(v)
|
_u.mutation.SetClaudeCodeOnly(v)
|
||||||
@@ -1036,6 +1057,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
|
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
|
||||||
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
|
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
|
||||||
|
_spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
|
||||||
|
_spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
||||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||||
}
|
}
|
||||||
@@ -1825,6 +1852,27 @@ func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||||
|
func (_u *GroupUpdateOne) SetSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
|
||||||
|
_u.mutation.ResetSoraStorageQuotaBytes()
|
||||||
|
_u.mutation.SetSoraStorageQuotaBytes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetSoraStorageQuotaBytes(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
|
||||||
|
func (_u *GroupUpdateOne) AddSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
|
||||||
|
_u.mutation.AddSoraStorageQuotaBytes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
|
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
|
||||||
_u.mutation.SetClaudeCodeOnly(v)
|
_u.mutation.SetClaudeCodeOnly(v)
|
||||||
@@ -2428,6 +2476,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
|||||||
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
|
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
|
||||||
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
|
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
|
||||||
|
_spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
|
||||||
|
_spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
||||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -93,6 +93,18 @@ func (f GroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error
|
|||||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.GroupMutation", m)
|
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.GroupMutation", m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The IdempotencyRecordFunc type is an adapter to allow the use of ordinary
|
||||||
|
// function as IdempotencyRecord mutator.
|
||||||
|
type IdempotencyRecordFunc func(context.Context, *ent.IdempotencyRecordMutation) (ent.Value, error)
|
||||||
|
|
||||||
|
// Mutate calls f(ctx, m).
|
||||||
|
func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
||||||
|
if mv, ok := m.(*ent.IdempotencyRecordMutation); ok {
|
||||||
|
return f(ctx, mv)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m)
|
||||||
|
}
|
||||||
|
|
||||||
// The PromoCodeFunc type is an adapter to allow the use of ordinary
|
// The PromoCodeFunc type is an adapter to allow the use of ordinary
|
||||||
// function as PromoCode mutator.
|
// function as PromoCode mutator.
|
||||||
type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error)
|
type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error)
|
||||||
|
|||||||
228
backend/ent/idempotencyrecord.go
Normal file
228
backend/ent/idempotencyrecord.go
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent"
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IdempotencyRecord is the model entity for the IdempotencyRecord schema.
|
||||||
|
type IdempotencyRecord struct {
|
||||||
|
config `json:"-"`
|
||||||
|
// ID of the ent.
|
||||||
|
ID int64 `json:"id,omitempty"`
|
||||||
|
// CreatedAt holds the value of the "created_at" field.
|
||||||
|
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||||
|
// UpdatedAt holds the value of the "updated_at" field.
|
||||||
|
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||||
|
// Scope holds the value of the "scope" field.
|
||||||
|
Scope string `json:"scope,omitempty"`
|
||||||
|
// IdempotencyKeyHash holds the value of the "idempotency_key_hash" field.
|
||||||
|
IdempotencyKeyHash string `json:"idempotency_key_hash,omitempty"`
|
||||||
|
// RequestFingerprint holds the value of the "request_fingerprint" field.
|
||||||
|
RequestFingerprint string `json:"request_fingerprint,omitempty"`
|
||||||
|
// Status holds the value of the "status" field.
|
||||||
|
Status string `json:"status,omitempty"`
|
||||||
|
// ResponseStatus holds the value of the "response_status" field.
|
||||||
|
ResponseStatus *int `json:"response_status,omitempty"`
|
||||||
|
// ResponseBody holds the value of the "response_body" field.
|
||||||
|
ResponseBody *string `json:"response_body,omitempty"`
|
||||||
|
// ErrorReason holds the value of the "error_reason" field.
|
||||||
|
ErrorReason *string `json:"error_reason,omitempty"`
|
||||||
|
// LockedUntil holds the value of the "locked_until" field.
|
||||||
|
LockedUntil *time.Time `json:"locked_until,omitempty"`
|
||||||
|
// ExpiresAt holds the value of the "expires_at" field.
|
||||||
|
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
||||||
|
selectValues sql.SelectValues
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanValues returns the types for scanning values from sql.Rows.
|
||||||
|
func (*IdempotencyRecord) scanValues(columns []string) ([]any, error) {
|
||||||
|
values := make([]any, len(columns))
|
||||||
|
for i := range columns {
|
||||||
|
switch columns[i] {
|
||||||
|
case idempotencyrecord.FieldID, idempotencyrecord.FieldResponseStatus:
|
||||||
|
values[i] = new(sql.NullInt64)
|
||||||
|
case idempotencyrecord.FieldScope, idempotencyrecord.FieldIdempotencyKeyHash, idempotencyrecord.FieldRequestFingerprint, idempotencyrecord.FieldStatus, idempotencyrecord.FieldResponseBody, idempotencyrecord.FieldErrorReason:
|
||||||
|
values[i] = new(sql.NullString)
|
||||||
|
case idempotencyrecord.FieldCreatedAt, idempotencyrecord.FieldUpdatedAt, idempotencyrecord.FieldLockedUntil, idempotencyrecord.FieldExpiresAt:
|
||||||
|
values[i] = new(sql.NullTime)
|
||||||
|
default:
|
||||||
|
values[i] = new(sql.UnknownType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return values, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// assignValues assigns the values that were returned from sql.Rows (after scanning)
|
||||||
|
// to the IdempotencyRecord fields.
|
||||||
|
func (_m *IdempotencyRecord) assignValues(columns []string, values []any) error {
|
||||||
|
if m, n := len(values), len(columns); m < n {
|
||||||
|
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
|
||||||
|
}
|
||||||
|
for i := range columns {
|
||||||
|
switch columns[i] {
|
||||||
|
case idempotencyrecord.FieldID:
|
||||||
|
value, ok := values[i].(*sql.NullInt64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field id", value)
|
||||||
|
}
|
||||||
|
_m.ID = int64(value.Int64)
|
||||||
|
case idempotencyrecord.FieldCreatedAt:
|
||||||
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.CreatedAt = value.Time
|
||||||
|
}
|
||||||
|
case idempotencyrecord.FieldUpdatedAt:
|
||||||
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.UpdatedAt = value.Time
|
||||||
|
}
|
||||||
|
case idempotencyrecord.FieldScope:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field scope", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.Scope = value.String
|
||||||
|
}
|
||||||
|
case idempotencyrecord.FieldIdempotencyKeyHash:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field idempotency_key_hash", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.IdempotencyKeyHash = value.String
|
||||||
|
}
|
||||||
|
case idempotencyrecord.FieldRequestFingerprint:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field request_fingerprint", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.RequestFingerprint = value.String
|
||||||
|
}
|
||||||
|
case idempotencyrecord.FieldStatus:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field status", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.Status = value.String
|
||||||
|
}
|
||||||
|
case idempotencyrecord.FieldResponseStatus:
|
||||||
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field response_status", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.ResponseStatus = new(int)
|
||||||
|
*_m.ResponseStatus = int(value.Int64)
|
||||||
|
}
|
||||||
|
case idempotencyrecord.FieldResponseBody:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field response_body", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.ResponseBody = new(string)
|
||||||
|
*_m.ResponseBody = value.String
|
||||||
|
}
|
||||||
|
case idempotencyrecord.FieldErrorReason:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field error_reason", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.ErrorReason = new(string)
|
||||||
|
*_m.ErrorReason = value.String
|
||||||
|
}
|
||||||
|
case idempotencyrecord.FieldLockedUntil:
|
||||||
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field locked_until", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.LockedUntil = new(time.Time)
|
||||||
|
*_m.LockedUntil = value.Time
|
||||||
|
}
|
||||||
|
case idempotencyrecord.FieldExpiresAt:
|
||||||
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field expires_at", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.ExpiresAt = value.Time
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value returns the ent.Value that was dynamically selected and assigned to the IdempotencyRecord.
|
||||||
|
// This includes values selected through modifiers, order, etc.
|
||||||
|
func (_m *IdempotencyRecord) Value(name string) (ent.Value, error) {
|
||||||
|
return _m.selectValues.Get(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update returns a builder for updating this IdempotencyRecord.
|
||||||
|
// Note that you need to call IdempotencyRecord.Unwrap() before calling this method if this IdempotencyRecord
|
||||||
|
// was returned from a transaction, and the transaction was committed or rolled back.
|
||||||
|
func (_m *IdempotencyRecord) Update() *IdempotencyRecordUpdateOne {
|
||||||
|
return NewIdempotencyRecordClient(_m.config).UpdateOne(_m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap unwraps the IdempotencyRecord entity that was returned from a transaction after it was closed,
|
||||||
|
// so that all future queries will be executed through the driver which created the transaction.
|
||||||
|
func (_m *IdempotencyRecord) Unwrap() *IdempotencyRecord {
|
||||||
|
_tx, ok := _m.config.driver.(*txDriver)
|
||||||
|
if !ok {
|
||||||
|
panic("ent: IdempotencyRecord is not a transactional entity")
|
||||||
|
}
|
||||||
|
_m.config.driver = _tx.drv
|
||||||
|
return _m
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements the fmt.Stringer.
|
||||||
|
func (_m *IdempotencyRecord) String() string {
|
||||||
|
var builder strings.Builder
|
||||||
|
builder.WriteString("IdempotencyRecord(")
|
||||||
|
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
|
||||||
|
builder.WriteString("created_at=")
|
||||||
|
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("updated_at=")
|
||||||
|
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("scope=")
|
||||||
|
builder.WriteString(_m.Scope)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("idempotency_key_hash=")
|
||||||
|
builder.WriteString(_m.IdempotencyKeyHash)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("request_fingerprint=")
|
||||||
|
builder.WriteString(_m.RequestFingerprint)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("status=")
|
||||||
|
builder.WriteString(_m.Status)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
if v := _m.ResponseStatus; v != nil {
|
||||||
|
builder.WriteString("response_status=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
|
if v := _m.ResponseBody; v != nil {
|
||||||
|
builder.WriteString("response_body=")
|
||||||
|
builder.WriteString(*v)
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
|
if v := _m.ErrorReason; v != nil {
|
||||||
|
builder.WriteString("error_reason=")
|
||||||
|
builder.WriteString(*v)
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
|
if v := _m.LockedUntil; v != nil {
|
||||||
|
builder.WriteString("locked_until=")
|
||||||
|
builder.WriteString(v.Format(time.ANSIC))
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("expires_at=")
|
||||||
|
builder.WriteString(_m.ExpiresAt.Format(time.ANSIC))
|
||||||
|
builder.WriteByte(')')
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyRecords is a parsable slice of IdempotencyRecord.
|
||||||
|
type IdempotencyRecords []*IdempotencyRecord
|
||||||
148
backend/ent/idempotencyrecord/idempotencyrecord.go
Normal file
148
backend/ent/idempotencyrecord/idempotencyrecord.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package idempotencyrecord
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Label holds the string label denoting the idempotencyrecord type in the database.
|
||||||
|
Label = "idempotency_record"
|
||||||
|
// FieldID holds the string denoting the id field in the database.
|
||||||
|
FieldID = "id"
|
||||||
|
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||||
|
FieldCreatedAt = "created_at"
|
||||||
|
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
|
||||||
|
FieldUpdatedAt = "updated_at"
|
||||||
|
// FieldScope holds the string denoting the scope field in the database.
|
||||||
|
FieldScope = "scope"
|
||||||
|
// FieldIdempotencyKeyHash holds the string denoting the idempotency_key_hash field in the database.
|
||||||
|
FieldIdempotencyKeyHash = "idempotency_key_hash"
|
||||||
|
// FieldRequestFingerprint holds the string denoting the request_fingerprint field in the database.
|
||||||
|
FieldRequestFingerprint = "request_fingerprint"
|
||||||
|
// FieldStatus holds the string denoting the status field in the database.
|
||||||
|
FieldStatus = "status"
|
||||||
|
// FieldResponseStatus holds the string denoting the response_status field in the database.
|
||||||
|
FieldResponseStatus = "response_status"
|
||||||
|
// FieldResponseBody holds the string denoting the response_body field in the database.
|
||||||
|
FieldResponseBody = "response_body"
|
||||||
|
// FieldErrorReason holds the string denoting the error_reason field in the database.
|
||||||
|
FieldErrorReason = "error_reason"
|
||||||
|
// FieldLockedUntil holds the string denoting the locked_until field in the database.
|
||||||
|
FieldLockedUntil = "locked_until"
|
||||||
|
// FieldExpiresAt holds the string denoting the expires_at field in the database.
|
||||||
|
FieldExpiresAt = "expires_at"
|
||||||
|
// Table holds the table name of the idempotencyrecord in the database.
|
||||||
|
Table = "idempotency_records"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Columns holds all SQL columns for idempotencyrecord fields.
|
||||||
|
var Columns = []string{
|
||||||
|
FieldID,
|
||||||
|
FieldCreatedAt,
|
||||||
|
FieldUpdatedAt,
|
||||||
|
FieldScope,
|
||||||
|
FieldIdempotencyKeyHash,
|
||||||
|
FieldRequestFingerprint,
|
||||||
|
FieldStatus,
|
||||||
|
FieldResponseStatus,
|
||||||
|
FieldResponseBody,
|
||||||
|
FieldErrorReason,
|
||||||
|
FieldLockedUntil,
|
||||||
|
FieldExpiresAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||||
|
func ValidColumn(column string) bool {
|
||||||
|
for i := range Columns {
|
||||||
|
if column == Columns[i] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||||
|
DefaultCreatedAt func() time.Time
|
||||||
|
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||||
|
DefaultUpdatedAt func() time.Time
|
||||||
|
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
|
||||||
|
UpdateDefaultUpdatedAt func() time.Time
|
||||||
|
// ScopeValidator is a validator for the "scope" field. It is called by the builders before save.
|
||||||
|
ScopeValidator func(string) error
|
||||||
|
// IdempotencyKeyHashValidator is a validator for the "idempotency_key_hash" field. It is called by the builders before save.
|
||||||
|
IdempotencyKeyHashValidator func(string) error
|
||||||
|
// RequestFingerprintValidator is a validator for the "request_fingerprint" field. It is called by the builders before save.
|
||||||
|
RequestFingerprintValidator func(string) error
|
||||||
|
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||||
|
StatusValidator func(string) error
|
||||||
|
// ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save.
|
||||||
|
ErrorReasonValidator func(string) error
|
||||||
|
)
|
||||||
|
|
||||||
|
// OrderOption defines the ordering options for the IdempotencyRecord queries.
|
||||||
|
type OrderOption func(*sql.Selector)
|
||||||
|
|
||||||
|
// ByID orders the results by the id field.
|
||||||
|
func ByID(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldID, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByCreatedAt orders the results by the created_at field.
|
||||||
|
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByUpdatedAt orders the results by the updated_at field.
|
||||||
|
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByScope orders the results by the scope field.
|
||||||
|
func ByScope(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldScope, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByIdempotencyKeyHash orders the results by the idempotency_key_hash field.
|
||||||
|
func ByIdempotencyKeyHash(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldIdempotencyKeyHash, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByRequestFingerprint orders the results by the request_fingerprint field.
|
||||||
|
func ByRequestFingerprint(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldRequestFingerprint, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByStatus orders the results by the status field.
|
||||||
|
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldStatus, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByResponseStatus orders the results by the response_status field.
|
||||||
|
func ByResponseStatus(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldResponseStatus, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByResponseBody orders the results by the response_body field.
|
||||||
|
func ByResponseBody(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldResponseBody, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByErrorReason orders the results by the error_reason field.
|
||||||
|
func ByErrorReason(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldErrorReason, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByLockedUntil orders the results by the locked_until field.
|
||||||
|
func ByLockedUntil(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldLockedUntil, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByExpiresAt orders the results by the expires_at field.
|
||||||
|
func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
|
||||||
|
}
|
||||||
755
backend/ent/idempotencyrecord/where.go
Normal file
755
backend/ent/idempotencyrecord/where.go
Normal file
@@ -0,0 +1,755 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package idempotencyrecord
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ID filters vertices based on their ID field.
|
||||||
|
func ID(id int64) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDEQ applies the EQ predicate on the ID field.
|
||||||
|
func IDEQ(id int64) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDNEQ applies the NEQ predicate on the ID field.
|
||||||
|
func IDNEQ(id int64) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNEQ(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDIn applies the In predicate on the ID field.
|
||||||
|
func IDIn(ids ...int64) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIn(FieldID, ids...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDNotIn applies the NotIn predicate on the ID field.
|
||||||
|
func IDNotIn(ids ...int64) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotIn(FieldID, ids...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDGT applies the GT predicate on the ID field.
|
||||||
|
func IDGT(id int64) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGT(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDGTE applies the GTE predicate on the ID field.
|
||||||
|
func IDGTE(id int64) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGTE(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDLT applies the LT predicate on the ID field.
|
||||||
|
func IDLT(id int64) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLT(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDLTE applies the LTE predicate on the ID field.
|
||||||
|
func IDLTE(id int64) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLTE(FieldID, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
||||||
|
func CreatedAt(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
|
||||||
|
func UpdatedAt(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scope applies equality check predicate on the "scope" field. It's identical to ScopeEQ.
|
||||||
|
func Scope(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldScope, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHash applies equality check predicate on the "idempotency_key_hash" field. It's identical to IdempotencyKeyHashEQ.
|
||||||
|
func IdempotencyKeyHash(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldIdempotencyKeyHash, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprint applies equality check predicate on the "request_fingerprint" field. It's identical to RequestFingerprintEQ.
|
||||||
|
func RequestFingerprint(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldRequestFingerprint, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
|
||||||
|
func Status(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseStatus applies equality check predicate on the "response_status" field. It's identical to ResponseStatusEQ.
|
||||||
|
func ResponseStatus(v int) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBody applies equality check predicate on the "response_body" field. It's identical to ResponseBodyEQ.
|
||||||
|
func ResponseBody(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseBody, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReason applies equality check predicate on the "error_reason" field. It's identical to ErrorReasonEQ.
|
||||||
|
func ErrorReason(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldErrorReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LockedUntil applies equality check predicate on the "locked_until" field. It's identical to LockedUntilEQ.
|
||||||
|
func LockedUntil(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldLockedUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
|
||||||
|
func ExpiresAt(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
|
func CreatedAtEQ(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
|
||||||
|
func CreatedAtNEQ(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNEQ(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtIn applies the In predicate on the "created_at" field.
|
||||||
|
func CreatedAtIn(vs ...time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIn(FieldCreatedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
|
||||||
|
func CreatedAtNotIn(vs ...time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotIn(FieldCreatedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtGT applies the GT predicate on the "created_at" field.
|
||||||
|
func CreatedAtGT(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGT(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
|
||||||
|
func CreatedAtGTE(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGTE(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtLT applies the LT predicate on the "created_at" field.
|
||||||
|
func CreatedAtLT(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLT(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
|
||||||
|
func CreatedAtLTE(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLTE(FieldCreatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtEQ(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtNEQ(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNEQ(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtIn applies the In predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtIn(vs ...time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIn(FieldUpdatedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtNotIn(vs ...time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotIn(FieldUpdatedAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtGT(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGT(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtGTE(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGTE(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtLT(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLT(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
|
||||||
|
func UpdatedAtLTE(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLTE(FieldUpdatedAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScopeEQ applies the EQ predicate on the "scope" field.
|
||||||
|
func ScopeEQ(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldScope, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScopeNEQ applies the NEQ predicate on the "scope" field.
|
||||||
|
func ScopeNEQ(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNEQ(FieldScope, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScopeIn applies the In predicate on the "scope" field.
|
||||||
|
func ScopeIn(vs ...string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIn(FieldScope, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScopeNotIn applies the NotIn predicate on the "scope" field.
|
||||||
|
func ScopeNotIn(vs ...string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotIn(FieldScope, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScopeGT applies the GT predicate on the "scope" field.
|
||||||
|
func ScopeGT(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGT(FieldScope, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScopeGTE applies the GTE predicate on the "scope" field.
|
||||||
|
func ScopeGTE(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGTE(FieldScope, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScopeLT applies the LT predicate on the "scope" field.
|
||||||
|
func ScopeLT(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLT(FieldScope, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScopeLTE applies the LTE predicate on the "scope" field.
|
||||||
|
func ScopeLTE(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLTE(FieldScope, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScopeContains applies the Contains predicate on the "scope" field.
|
||||||
|
func ScopeContains(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldContains(FieldScope, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScopeHasPrefix applies the HasPrefix predicate on the "scope" field.
|
||||||
|
func ScopeHasPrefix(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldScope, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScopeHasSuffix applies the HasSuffix predicate on the "scope" field.
|
||||||
|
func ScopeHasSuffix(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldScope, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScopeEqualFold applies the EqualFold predicate on the "scope" field.
|
||||||
|
func ScopeEqualFold(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldScope, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScopeContainsFold applies the ContainsFold predicate on the "scope" field.
|
||||||
|
func ScopeContainsFold(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldScope, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHashEQ applies the EQ predicate on the "idempotency_key_hash" field.
|
||||||
|
func IdempotencyKeyHashEQ(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldIdempotencyKeyHash, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHashNEQ applies the NEQ predicate on the "idempotency_key_hash" field.
|
||||||
|
func IdempotencyKeyHashNEQ(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNEQ(FieldIdempotencyKeyHash, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHashIn applies the In predicate on the "idempotency_key_hash" field.
|
||||||
|
func IdempotencyKeyHashIn(vs ...string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIn(FieldIdempotencyKeyHash, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHashNotIn applies the NotIn predicate on the "idempotency_key_hash" field.
|
||||||
|
func IdempotencyKeyHashNotIn(vs ...string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotIn(FieldIdempotencyKeyHash, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHashGT applies the GT predicate on the "idempotency_key_hash" field.
|
||||||
|
func IdempotencyKeyHashGT(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGT(FieldIdempotencyKeyHash, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHashGTE applies the GTE predicate on the "idempotency_key_hash" field.
|
||||||
|
func IdempotencyKeyHashGTE(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGTE(FieldIdempotencyKeyHash, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHashLT applies the LT predicate on the "idempotency_key_hash" field.
|
||||||
|
func IdempotencyKeyHashLT(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLT(FieldIdempotencyKeyHash, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHashLTE applies the LTE predicate on the "idempotency_key_hash" field.
|
||||||
|
func IdempotencyKeyHashLTE(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLTE(FieldIdempotencyKeyHash, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHashContains applies the Contains predicate on the "idempotency_key_hash" field.
|
||||||
|
func IdempotencyKeyHashContains(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldContains(FieldIdempotencyKeyHash, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHashHasPrefix applies the HasPrefix predicate on the "idempotency_key_hash" field.
|
||||||
|
func IdempotencyKeyHashHasPrefix(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldIdempotencyKeyHash, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHashHasSuffix applies the HasSuffix predicate on the "idempotency_key_hash" field.
|
||||||
|
func IdempotencyKeyHashHasSuffix(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldIdempotencyKeyHash, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHashEqualFold applies the EqualFold predicate on the "idempotency_key_hash" field.
|
||||||
|
func IdempotencyKeyHashEqualFold(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldIdempotencyKeyHash, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyKeyHashContainsFold applies the ContainsFold predicate on the "idempotency_key_hash" field.
|
||||||
|
func IdempotencyKeyHashContainsFold(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldIdempotencyKeyHash, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprintEQ applies the EQ predicate on the "request_fingerprint" field.
|
||||||
|
func RequestFingerprintEQ(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldRequestFingerprint, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprintNEQ applies the NEQ predicate on the "request_fingerprint" field.
|
||||||
|
func RequestFingerprintNEQ(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNEQ(FieldRequestFingerprint, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprintIn applies the In predicate on the "request_fingerprint" field.
|
||||||
|
func RequestFingerprintIn(vs ...string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIn(FieldRequestFingerprint, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprintNotIn applies the NotIn predicate on the "request_fingerprint" field.
|
||||||
|
func RequestFingerprintNotIn(vs ...string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotIn(FieldRequestFingerprint, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprintGT applies the GT predicate on the "request_fingerprint" field.
|
||||||
|
func RequestFingerprintGT(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGT(FieldRequestFingerprint, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprintGTE applies the GTE predicate on the "request_fingerprint" field.
|
||||||
|
func RequestFingerprintGTE(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGTE(FieldRequestFingerprint, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprintLT applies the LT predicate on the "request_fingerprint" field.
|
||||||
|
func RequestFingerprintLT(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLT(FieldRequestFingerprint, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprintLTE applies the LTE predicate on the "request_fingerprint" field.
|
||||||
|
func RequestFingerprintLTE(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLTE(FieldRequestFingerprint, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprintContains applies the Contains predicate on the "request_fingerprint" field.
|
||||||
|
func RequestFingerprintContains(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldContains(FieldRequestFingerprint, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprintHasPrefix applies the HasPrefix predicate on the "request_fingerprint" field.
|
||||||
|
func RequestFingerprintHasPrefix(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldRequestFingerprint, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprintHasSuffix applies the HasSuffix predicate on the "request_fingerprint" field.
|
||||||
|
func RequestFingerprintHasSuffix(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldRequestFingerprint, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprintEqualFold applies the EqualFold predicate on the "request_fingerprint" field.
|
||||||
|
func RequestFingerprintEqualFold(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldRequestFingerprint, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestFingerprintContainsFold applies the ContainsFold predicate on the "request_fingerprint" field.
|
||||||
|
func RequestFingerprintContainsFold(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldRequestFingerprint, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusEQ applies the EQ predicate on the "status" field.
|
||||||
|
func StatusEQ(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusNEQ applies the NEQ predicate on the "status" field.
|
||||||
|
func StatusNEQ(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNEQ(FieldStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusIn applies the In predicate on the "status" field.
|
||||||
|
func StatusIn(vs ...string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIn(FieldStatus, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusNotIn applies the NotIn predicate on the "status" field.
|
||||||
|
func StatusNotIn(vs ...string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotIn(FieldStatus, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusGT applies the GT predicate on the "status" field.
|
||||||
|
func StatusGT(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGT(FieldStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusGTE applies the GTE predicate on the "status" field.
|
||||||
|
func StatusGTE(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGTE(FieldStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusLT applies the LT predicate on the "status" field.
|
||||||
|
func StatusLT(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLT(FieldStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusLTE applies the LTE predicate on the "status" field.
|
||||||
|
func StatusLTE(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLTE(FieldStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusContains applies the Contains predicate on the "status" field.
|
||||||
|
func StatusContains(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldContains(FieldStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusHasPrefix applies the HasPrefix predicate on the "status" field.
|
||||||
|
func StatusHasPrefix(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusHasSuffix applies the HasSuffix predicate on the "status" field.
|
||||||
|
func StatusHasSuffix(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusEqualFold applies the EqualFold predicate on the "status" field.
|
||||||
|
func StatusEqualFold(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusContainsFold applies the ContainsFold predicate on the "status" field.
|
||||||
|
func StatusContainsFold(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseStatusEQ applies the EQ predicate on the "response_status" field.
|
||||||
|
func ResponseStatusEQ(v int) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseStatusNEQ applies the NEQ predicate on the "response_status" field.
|
||||||
|
func ResponseStatusNEQ(v int) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNEQ(FieldResponseStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseStatusIn applies the In predicate on the "response_status" field.
|
||||||
|
func ResponseStatusIn(vs ...int) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIn(FieldResponseStatus, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseStatusNotIn applies the NotIn predicate on the "response_status" field.
|
||||||
|
func ResponseStatusNotIn(vs ...int) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotIn(FieldResponseStatus, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseStatusGT applies the GT predicate on the "response_status" field.
|
||||||
|
func ResponseStatusGT(v int) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGT(FieldResponseStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseStatusGTE applies the GTE predicate on the "response_status" field.
|
||||||
|
func ResponseStatusGTE(v int) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGTE(FieldResponseStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseStatusLT applies the LT predicate on the "response_status" field.
|
||||||
|
func ResponseStatusLT(v int) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLT(FieldResponseStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseStatusLTE applies the LTE predicate on the "response_status" field.
|
||||||
|
func ResponseStatusLTE(v int) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLTE(FieldResponseStatus, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseStatusIsNil applies the IsNil predicate on the "response_status" field.
|
||||||
|
func ResponseStatusIsNil() predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIsNull(FieldResponseStatus))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseStatusNotNil applies the NotNil predicate on the "response_status" field.
|
||||||
|
func ResponseStatusNotNil() predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotNull(FieldResponseStatus))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyEQ applies the EQ predicate on the "response_body" field.
|
||||||
|
func ResponseBodyEQ(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseBody, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyNEQ applies the NEQ predicate on the "response_body" field.
|
||||||
|
func ResponseBodyNEQ(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNEQ(FieldResponseBody, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyIn applies the In predicate on the "response_body" field.
|
||||||
|
func ResponseBodyIn(vs ...string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIn(FieldResponseBody, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyNotIn applies the NotIn predicate on the "response_body" field.
|
||||||
|
func ResponseBodyNotIn(vs ...string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotIn(FieldResponseBody, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyGT applies the GT predicate on the "response_body" field.
|
||||||
|
func ResponseBodyGT(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGT(FieldResponseBody, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyGTE applies the GTE predicate on the "response_body" field.
|
||||||
|
func ResponseBodyGTE(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGTE(FieldResponseBody, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyLT applies the LT predicate on the "response_body" field.
|
||||||
|
func ResponseBodyLT(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLT(FieldResponseBody, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyLTE applies the LTE predicate on the "response_body" field.
|
||||||
|
func ResponseBodyLTE(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLTE(FieldResponseBody, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyContains applies the Contains predicate on the "response_body" field.
|
||||||
|
func ResponseBodyContains(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldContains(FieldResponseBody, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyHasPrefix applies the HasPrefix predicate on the "response_body" field.
|
||||||
|
func ResponseBodyHasPrefix(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldResponseBody, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyHasSuffix applies the HasSuffix predicate on the "response_body" field.
|
||||||
|
func ResponseBodyHasSuffix(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldResponseBody, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyIsNil applies the IsNil predicate on the "response_body" field.
|
||||||
|
func ResponseBodyIsNil() predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIsNull(FieldResponseBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyNotNil applies the NotNil predicate on the "response_body" field.
|
||||||
|
func ResponseBodyNotNil() predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotNull(FieldResponseBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyEqualFold applies the EqualFold predicate on the "response_body" field.
|
||||||
|
func ResponseBodyEqualFold(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldResponseBody, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseBodyContainsFold applies the ContainsFold predicate on the "response_body" field.
|
||||||
|
func ResponseBodyContainsFold(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldResponseBody, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonEQ applies the EQ predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonEQ(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldErrorReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonNEQ applies the NEQ predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonNEQ(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNEQ(FieldErrorReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonIn applies the In predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonIn(vs ...string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIn(FieldErrorReason, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonNotIn applies the NotIn predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonNotIn(vs ...string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotIn(FieldErrorReason, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonGT applies the GT predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonGT(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGT(FieldErrorReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonGTE applies the GTE predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonGTE(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGTE(FieldErrorReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonLT applies the LT predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonLT(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLT(FieldErrorReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonLTE applies the LTE predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonLTE(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLTE(FieldErrorReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonContains applies the Contains predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonContains(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldContains(FieldErrorReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonHasPrefix applies the HasPrefix predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonHasPrefix(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldErrorReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonHasSuffix applies the HasSuffix predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonHasSuffix(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldErrorReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonIsNil applies the IsNil predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonIsNil() predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIsNull(FieldErrorReason))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonNotNil applies the NotNil predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonNotNil() predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotNull(FieldErrorReason))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonEqualFold applies the EqualFold predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonEqualFold(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldErrorReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorReasonContainsFold applies the ContainsFold predicate on the "error_reason" field.
|
||||||
|
func ErrorReasonContainsFold(v string) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldErrorReason, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LockedUntilEQ applies the EQ predicate on the "locked_until" field.
|
||||||
|
func LockedUntilEQ(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldLockedUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LockedUntilNEQ applies the NEQ predicate on the "locked_until" field.
|
||||||
|
func LockedUntilNEQ(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNEQ(FieldLockedUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LockedUntilIn applies the In predicate on the "locked_until" field.
|
||||||
|
func LockedUntilIn(vs ...time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIn(FieldLockedUntil, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LockedUntilNotIn applies the NotIn predicate on the "locked_until" field.
|
||||||
|
func LockedUntilNotIn(vs ...time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotIn(FieldLockedUntil, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LockedUntilGT applies the GT predicate on the "locked_until" field.
|
||||||
|
func LockedUntilGT(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGT(FieldLockedUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LockedUntilGTE applies the GTE predicate on the "locked_until" field.
|
||||||
|
func LockedUntilGTE(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGTE(FieldLockedUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LockedUntilLT applies the LT predicate on the "locked_until" field.
|
||||||
|
func LockedUntilLT(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLT(FieldLockedUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LockedUntilLTE applies the LTE predicate on the "locked_until" field.
|
||||||
|
func LockedUntilLTE(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLTE(FieldLockedUntil, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LockedUntilIsNil applies the IsNil predicate on the "locked_until" field.
|
||||||
|
func LockedUntilIsNil() predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIsNull(FieldLockedUntil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LockedUntilNotNil applies the NotNil predicate on the "locked_until" field.
|
||||||
|
func LockedUntilNotNil() predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotNull(FieldLockedUntil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtEQ(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldEQ(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtNEQ(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNEQ(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtIn applies the In predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtIn(vs ...time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldIn(FieldExpiresAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtNotIn(vs ...time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldNotIn(FieldExpiresAt, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtGT applies the GT predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtGT(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGT(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtGTE(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldGTE(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtLT applies the LT predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtLT(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLT(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
|
||||||
|
func ExpiresAtLTE(v time.Time) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.FieldLTE(FieldExpiresAt, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// And groups predicates with the AND operator between them.
|
||||||
|
func And(predicates ...predicate.IdempotencyRecord) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.AndPredicates(predicates...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Or groups predicates with the OR operator between them.
|
||||||
|
func Or(predicates ...predicate.IdempotencyRecord) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.OrPredicates(predicates...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not applies the not operator on the given predicate.
|
||||||
|
func Not(p predicate.IdempotencyRecord) predicate.IdempotencyRecord {
|
||||||
|
return predicate.IdempotencyRecord(sql.NotPredicates(p))
|
||||||
|
}
|
||||||
1132
backend/ent/idempotencyrecord_create.go
Normal file
1132
backend/ent/idempotencyrecord_create.go
Normal file
File diff suppressed because it is too large
Load Diff
88
backend/ent/idempotencyrecord_delete.go
Normal file
88
backend/ent/idempotencyrecord_delete.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"entgo.io/ent/schema/field"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IdempotencyRecordDelete is the builder for deleting a IdempotencyRecord entity.
|
||||||
|
type IdempotencyRecordDelete struct {
|
||||||
|
config
|
||||||
|
hooks []Hook
|
||||||
|
mutation *IdempotencyRecordMutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where appends a list predicates to the IdempotencyRecordDelete builder.
|
||||||
|
func (_d *IdempotencyRecordDelete) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordDelete {
|
||||||
|
_d.mutation.Where(ps...)
|
||||||
|
return _d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the deletion query and returns how many vertices were deleted.
|
||||||
|
func (_d *IdempotencyRecordDelete) Exec(ctx context.Context) (int, error) {
|
||||||
|
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_d *IdempotencyRecordDelete) ExecX(ctx context.Context) int {
|
||||||
|
n, err := _d.Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_d *IdempotencyRecordDelete) sqlExec(ctx context.Context) (int, error) {
|
||||||
|
_spec := sqlgraph.NewDeleteSpec(idempotencyrecord.Table, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64))
|
||||||
|
if ps := _d.mutation.predicates; len(ps) > 0 {
|
||||||
|
_spec.Predicate = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
|
||||||
|
if err != nil && sqlgraph.IsConstraintError(err) {
|
||||||
|
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||||
|
}
|
||||||
|
_d.mutation.done = true
|
||||||
|
return affected, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyRecordDeleteOne is the builder for deleting a single IdempotencyRecord entity.
|
||||||
|
type IdempotencyRecordDeleteOne struct {
|
||||||
|
_d *IdempotencyRecordDelete
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where appends a list predicates to the IdempotencyRecordDelete builder.
|
||||||
|
func (_d *IdempotencyRecordDeleteOne) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordDeleteOne {
|
||||||
|
_d._d.mutation.Where(ps...)
|
||||||
|
return _d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the deletion query.
|
||||||
|
func (_d *IdempotencyRecordDeleteOne) Exec(ctx context.Context) error {
|
||||||
|
n, err := _d._d.Exec(ctx)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return err
|
||||||
|
case n == 0:
|
||||||
|
return &NotFoundError{idempotencyrecord.Label}
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_d *IdempotencyRecordDeleteOne) ExecX(ctx context.Context) {
|
||||||
|
if err := _d.Exec(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
564
backend/ent/idempotencyrecord_query.go
Normal file
564
backend/ent/idempotencyrecord_query.go
Normal file
@@ -0,0 +1,564 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"entgo.io/ent"
|
||||||
|
"entgo.io/ent/dialect"
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"entgo.io/ent/schema/field"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IdempotencyRecordQuery is the builder for querying IdempotencyRecord entities.
|
||||||
|
type IdempotencyRecordQuery struct {
|
||||||
|
config
|
||||||
|
ctx *QueryContext
|
||||||
|
order []idempotencyrecord.OrderOption
|
||||||
|
inters []Interceptor
|
||||||
|
predicates []predicate.IdempotencyRecord
|
||||||
|
modifiers []func(*sql.Selector)
|
||||||
|
// intermediate query (i.e. traversal path).
|
||||||
|
sql *sql.Selector
|
||||||
|
path func(context.Context) (*sql.Selector, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where adds a new predicate for the IdempotencyRecordQuery builder.
|
||||||
|
func (_q *IdempotencyRecordQuery) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordQuery {
|
||||||
|
_q.predicates = append(_q.predicates, ps...)
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit the number of records to be returned by this query.
|
||||||
|
func (_q *IdempotencyRecordQuery) Limit(limit int) *IdempotencyRecordQuery {
|
||||||
|
_q.ctx.Limit = &limit
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Offset to start from.
|
||||||
|
func (_q *IdempotencyRecordQuery) Offset(offset int) *IdempotencyRecordQuery {
|
||||||
|
_q.ctx.Offset = &offset
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unique configures the query builder to filter duplicate records on query.
|
||||||
|
// By default, unique is set to true, and can be disabled using this method.
|
||||||
|
func (_q *IdempotencyRecordQuery) Unique(unique bool) *IdempotencyRecordQuery {
|
||||||
|
_q.ctx.Unique = &unique
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Order specifies how the records should be ordered.
|
||||||
|
func (_q *IdempotencyRecordQuery) Order(o ...idempotencyrecord.OrderOption) *IdempotencyRecordQuery {
|
||||||
|
_q.order = append(_q.order, o...)
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// First returns the first IdempotencyRecord entity from the query.
|
||||||
|
// Returns a *NotFoundError when no IdempotencyRecord was found.
|
||||||
|
func (_q *IdempotencyRecordQuery) First(ctx context.Context) (*IdempotencyRecord, error) {
|
||||||
|
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(nodes) == 0 {
|
||||||
|
return nil, &NotFoundError{idempotencyrecord.Label}
|
||||||
|
}
|
||||||
|
return nodes[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstX is like First, but panics if an error occurs.
|
||||||
|
func (_q *IdempotencyRecordQuery) FirstX(ctx context.Context) *IdempotencyRecord {
|
||||||
|
node, err := _q.First(ctx)
|
||||||
|
if err != nil && !IsNotFound(err) {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstID returns the first IdempotencyRecord ID from the query.
|
||||||
|
// Returns a *NotFoundError when no IdempotencyRecord ID was found.
|
||||||
|
func (_q *IdempotencyRecordQuery) FirstID(ctx context.Context) (id int64, err error) {
|
||||||
|
var ids []int64
|
||||||
|
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
err = &NotFoundError{idempotencyrecord.Label}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return ids[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstIDX is like FirstID, but panics if an error occurs.
|
||||||
|
func (_q *IdempotencyRecordQuery) FirstIDX(ctx context.Context) int64 {
|
||||||
|
id, err := _q.FirstID(ctx)
|
||||||
|
if err != nil && !IsNotFound(err) {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only returns a single IdempotencyRecord entity found by the query, ensuring it only returns one.
|
||||||
|
// Returns a *NotSingularError when more than one IdempotencyRecord entity is found.
|
||||||
|
// Returns a *NotFoundError when no IdempotencyRecord entities are found.
|
||||||
|
func (_q *IdempotencyRecordQuery) Only(ctx context.Context) (*IdempotencyRecord, error) {
|
||||||
|
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch len(nodes) {
|
||||||
|
case 1:
|
||||||
|
return nodes[0], nil
|
||||||
|
case 0:
|
||||||
|
return nil, &NotFoundError{idempotencyrecord.Label}
|
||||||
|
default:
|
||||||
|
return nil, &NotSingularError{idempotencyrecord.Label}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnlyX is like Only, but panics if an error occurs.
|
||||||
|
func (_q *IdempotencyRecordQuery) OnlyX(ctx context.Context) *IdempotencyRecord {
|
||||||
|
node, err := _q.Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnlyID is like Only, but returns the only IdempotencyRecord ID in the query.
|
||||||
|
// Returns a *NotSingularError when more than one IdempotencyRecord ID is found.
|
||||||
|
// Returns a *NotFoundError when no entities are found.
|
||||||
|
func (_q *IdempotencyRecordQuery) OnlyID(ctx context.Context) (id int64, err error) {
|
||||||
|
var ids []int64
|
||||||
|
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch len(ids) {
|
||||||
|
case 1:
|
||||||
|
id = ids[0]
|
||||||
|
case 0:
|
||||||
|
err = &NotFoundError{idempotencyrecord.Label}
|
||||||
|
default:
|
||||||
|
err = &NotSingularError{idempotencyrecord.Label}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnlyIDX is like OnlyID, but panics if an error occurs.
|
||||||
|
func (_q *IdempotencyRecordQuery) OnlyIDX(ctx context.Context) int64 {
|
||||||
|
id, err := _q.OnlyID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// All executes the query and returns a list of IdempotencyRecords.
|
||||||
|
func (_q *IdempotencyRecordQuery) All(ctx context.Context) ([]*IdempotencyRecord, error) {
|
||||||
|
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
|
||||||
|
if err := _q.prepareQuery(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
qr := querierAll[[]*IdempotencyRecord, *IdempotencyRecordQuery]()
|
||||||
|
return withInterceptors[[]*IdempotencyRecord](ctx, _q, qr, _q.inters)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllX is like All, but panics if an error occurs.
|
||||||
|
func (_q *IdempotencyRecordQuery) AllX(ctx context.Context) []*IdempotencyRecord {
|
||||||
|
nodes, err := _q.All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDs executes the query and returns a list of IdempotencyRecord IDs.
|
||||||
|
func (_q *IdempotencyRecordQuery) IDs(ctx context.Context) (ids []int64, err error) {
|
||||||
|
if _q.ctx.Unique == nil && _q.path != nil {
|
||||||
|
_q.Unique(true)
|
||||||
|
}
|
||||||
|
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
|
||||||
|
if err = _q.Select(idempotencyrecord.FieldID).Scan(ctx, &ids); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDsX is like IDs, but panics if an error occurs.
|
||||||
|
func (_q *IdempotencyRecordQuery) IDsX(ctx context.Context) []int64 {
|
||||||
|
ids, err := _q.IDs(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the count of the given query.
|
||||||
|
func (_q *IdempotencyRecordQuery) Count(ctx context.Context) (int, error) {
|
||||||
|
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
|
||||||
|
if err := _q.prepareQuery(ctx); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return withInterceptors[int](ctx, _q, querierCount[*IdempotencyRecordQuery](), _q.inters)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountX is like Count, but panics if an error occurs.
|
||||||
|
func (_q *IdempotencyRecordQuery) CountX(ctx context.Context) int {
|
||||||
|
count, err := _q.Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exist returns true if the query has elements in the graph.
|
||||||
|
func (_q *IdempotencyRecordQuery) Exist(ctx context.Context) (bool, error) {
|
||||||
|
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
|
||||||
|
switch _, err := _q.FirstID(ctx); {
|
||||||
|
case IsNotFound(err):
|
||||||
|
return false, nil
|
||||||
|
case err != nil:
|
||||||
|
return false, fmt.Errorf("ent: check existence: %w", err)
|
||||||
|
default:
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExistX is like Exist, but panics if an error occurs.
|
||||||
|
func (_q *IdempotencyRecordQuery) ExistX(ctx context.Context) bool {
|
||||||
|
exist, err := _q.Exist(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return exist
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone returns a duplicate of the IdempotencyRecordQuery builder, including all associated steps. It can be
|
||||||
|
// used to prepare common query builders and use them differently after the clone is made.
|
||||||
|
func (_q *IdempotencyRecordQuery) Clone() *IdempotencyRecordQuery {
|
||||||
|
if _q == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &IdempotencyRecordQuery{
|
||||||
|
config: _q.config,
|
||||||
|
ctx: _q.ctx.Clone(),
|
||||||
|
order: append([]idempotencyrecord.OrderOption{}, _q.order...),
|
||||||
|
inters: append([]Interceptor{}, _q.inters...),
|
||||||
|
predicates: append([]predicate.IdempotencyRecord{}, _q.predicates...),
|
||||||
|
// clone intermediate query.
|
||||||
|
sql: _q.sql.Clone(),
|
||||||
|
path: _q.path,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupBy is used to group vertices by one or more fields/columns.
|
||||||
|
// It is often used with aggregate functions, like: count, max, mean, min, sum.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// var v []struct {
|
||||||
|
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||||
|
// Count int `json:"count,omitempty"`
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// client.IdempotencyRecord.Query().
|
||||||
|
// GroupBy(idempotencyrecord.FieldCreatedAt).
|
||||||
|
// Aggregate(ent.Count()).
|
||||||
|
// Scan(ctx, &v)
|
||||||
|
func (_q *IdempotencyRecordQuery) GroupBy(field string, fields ...string) *IdempotencyRecordGroupBy {
|
||||||
|
_q.ctx.Fields = append([]string{field}, fields...)
|
||||||
|
grbuild := &IdempotencyRecordGroupBy{build: _q}
|
||||||
|
grbuild.flds = &_q.ctx.Fields
|
||||||
|
grbuild.label = idempotencyrecord.Label
|
||||||
|
grbuild.scan = grbuild.Scan
|
||||||
|
return grbuild
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select allows the selection one or more fields/columns for the given query,
|
||||||
|
// instead of selecting all fields in the entity.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// var v []struct {
|
||||||
|
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// client.IdempotencyRecord.Query().
|
||||||
|
// Select(idempotencyrecord.FieldCreatedAt).
|
||||||
|
// Scan(ctx, &v)
|
||||||
|
func (_q *IdempotencyRecordQuery) Select(fields ...string) *IdempotencyRecordSelect {
|
||||||
|
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
|
||||||
|
sbuild := &IdempotencyRecordSelect{IdempotencyRecordQuery: _q}
|
||||||
|
sbuild.label = idempotencyrecord.Label
|
||||||
|
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
|
||||||
|
return sbuild
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate returns a IdempotencyRecordSelect configured with the given aggregations.
|
||||||
|
func (_q *IdempotencyRecordQuery) Aggregate(fns ...AggregateFunc) *IdempotencyRecordSelect {
|
||||||
|
return _q.Select().Aggregate(fns...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *IdempotencyRecordQuery) prepareQuery(ctx context.Context) error {
|
||||||
|
for _, inter := range _q.inters {
|
||||||
|
if inter == nil {
|
||||||
|
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
|
||||||
|
}
|
||||||
|
if trv, ok := inter.(Traverser); ok {
|
||||||
|
if err := trv.Traverse(ctx, _q); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, f := range _q.ctx.Fields {
|
||||||
|
if !idempotencyrecord.ValidColumn(f) {
|
||||||
|
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _q.path != nil {
|
||||||
|
prev, err := _q.path(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_q.sql = prev
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *IdempotencyRecordQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdempotencyRecord, error) {
|
||||||
|
var (
|
||||||
|
nodes = []*IdempotencyRecord{}
|
||||||
|
_spec = _q.querySpec()
|
||||||
|
)
|
||||||
|
_spec.ScanValues = func(columns []string) ([]any, error) {
|
||||||
|
return (*IdempotencyRecord).scanValues(nil, columns)
|
||||||
|
}
|
||||||
|
_spec.Assign = func(columns []string, values []any) error {
|
||||||
|
node := &IdempotencyRecord{config: _q.config}
|
||||||
|
nodes = append(nodes, node)
|
||||||
|
return node.assignValues(columns, values)
|
||||||
|
}
|
||||||
|
if len(_q.modifiers) > 0 {
|
||||||
|
_spec.Modifiers = _q.modifiers
|
||||||
|
}
|
||||||
|
for i := range hooks {
|
||||||
|
hooks[i](ctx, _spec)
|
||||||
|
}
|
||||||
|
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(nodes) == 0 {
|
||||||
|
return nodes, nil
|
||||||
|
}
|
||||||
|
return nodes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *IdempotencyRecordQuery) sqlCount(ctx context.Context) (int, error) {
|
||||||
|
_spec := _q.querySpec()
|
||||||
|
if len(_q.modifiers) > 0 {
|
||||||
|
_spec.Modifiers = _q.modifiers
|
||||||
|
}
|
||||||
|
_spec.Node.Columns = _q.ctx.Fields
|
||||||
|
if len(_q.ctx.Fields) > 0 {
|
||||||
|
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||||
|
}
|
||||||
|
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *IdempotencyRecordQuery) querySpec() *sqlgraph.QuerySpec {
|
||||||
|
_spec := sqlgraph.NewQuerySpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64))
|
||||||
|
_spec.From = _q.sql
|
||||||
|
if unique := _q.ctx.Unique; unique != nil {
|
||||||
|
_spec.Unique = *unique
|
||||||
|
} else if _q.path != nil {
|
||||||
|
_spec.Unique = true
|
||||||
|
}
|
||||||
|
if fields := _q.ctx.Fields; len(fields) > 0 {
|
||||||
|
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||||
|
_spec.Node.Columns = append(_spec.Node.Columns, idempotencyrecord.FieldID)
|
||||||
|
for i := range fields {
|
||||||
|
if fields[i] != idempotencyrecord.FieldID {
|
||||||
|
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ps := _q.predicates; len(ps) > 0 {
|
||||||
|
_spec.Predicate = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if limit := _q.ctx.Limit; limit != nil {
|
||||||
|
_spec.Limit = *limit
|
||||||
|
}
|
||||||
|
if offset := _q.ctx.Offset; offset != nil {
|
||||||
|
_spec.Offset = *offset
|
||||||
|
}
|
||||||
|
if ps := _q.order; len(ps) > 0 {
|
||||||
|
_spec.Order = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return _spec
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_q *IdempotencyRecordQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||||
|
builder := sql.Dialect(_q.driver.Dialect())
|
||||||
|
t1 := builder.Table(idempotencyrecord.Table)
|
||||||
|
columns := _q.ctx.Fields
|
||||||
|
if len(columns) == 0 {
|
||||||
|
columns = idempotencyrecord.Columns
|
||||||
|
}
|
||||||
|
selector := builder.Select(t1.Columns(columns...)...).From(t1)
|
||||||
|
if _q.sql != nil {
|
||||||
|
selector = _q.sql
|
||||||
|
selector.Select(selector.Columns(columns...)...)
|
||||||
|
}
|
||||||
|
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||||
|
selector.Distinct()
|
||||||
|
}
|
||||||
|
for _, m := range _q.modifiers {
|
||||||
|
m(selector)
|
||||||
|
}
|
||||||
|
for _, p := range _q.predicates {
|
||||||
|
p(selector)
|
||||||
|
}
|
||||||
|
for _, p := range _q.order {
|
||||||
|
p(selector)
|
||||||
|
}
|
||||||
|
if offset := _q.ctx.Offset; offset != nil {
|
||||||
|
// limit is mandatory for offset clause. We start
|
||||||
|
// with default value, and override it below if needed.
|
||||||
|
selector.Offset(*offset).Limit(math.MaxInt32)
|
||||||
|
}
|
||||||
|
if limit := _q.ctx.Limit; limit != nil {
|
||||||
|
selector.Limit(*limit)
|
||||||
|
}
|
||||||
|
return selector
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||||
|
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||||
|
// either committed or rolled-back.
|
||||||
|
func (_q *IdempotencyRecordQuery) ForUpdate(opts ...sql.LockOption) *IdempotencyRecordQuery {
|
||||||
|
if _q.driver.Dialect() == dialect.Postgres {
|
||||||
|
_q.Unique(false)
|
||||||
|
}
|
||||||
|
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||||
|
s.ForUpdate(opts...)
|
||||||
|
})
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||||
|
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||||
|
// until your transaction commits.
|
||||||
|
func (_q *IdempotencyRecordQuery) ForShare(opts ...sql.LockOption) *IdempotencyRecordQuery {
|
||||||
|
if _q.driver.Dialect() == dialect.Postgres {
|
||||||
|
_q.Unique(false)
|
||||||
|
}
|
||||||
|
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||||
|
s.ForShare(opts...)
|
||||||
|
})
|
||||||
|
return _q
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyRecordGroupBy is the group-by builder for IdempotencyRecord entities.
|
||||||
|
type IdempotencyRecordGroupBy struct {
|
||||||
|
selector
|
||||||
|
build *IdempotencyRecordQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate adds the given aggregation functions to the group-by query.
|
||||||
|
func (_g *IdempotencyRecordGroupBy) Aggregate(fns ...AggregateFunc) *IdempotencyRecordGroupBy {
|
||||||
|
_g.fns = append(_g.fns, fns...)
|
||||||
|
return _g
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan applies the selector query and scans the result into the given value.
|
||||||
|
func (_g *IdempotencyRecordGroupBy) Scan(ctx context.Context, v any) error {
|
||||||
|
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
|
||||||
|
if err := _g.build.prepareQuery(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return scanWithInterceptors[*IdempotencyRecordQuery, *IdempotencyRecordGroupBy](ctx, _g.build, _g, _g.build.inters, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_g *IdempotencyRecordGroupBy) sqlScan(ctx context.Context, root *IdempotencyRecordQuery, v any) error {
|
||||||
|
selector := root.sqlQuery(ctx).Select()
|
||||||
|
aggregation := make([]string, 0, len(_g.fns))
|
||||||
|
for _, fn := range _g.fns {
|
||||||
|
aggregation = append(aggregation, fn(selector))
|
||||||
|
}
|
||||||
|
if len(selector.SelectedColumns()) == 0 {
|
||||||
|
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
|
||||||
|
for _, f := range *_g.flds {
|
||||||
|
columns = append(columns, selector.C(f))
|
||||||
|
}
|
||||||
|
columns = append(columns, aggregation...)
|
||||||
|
selector.Select(columns...)
|
||||||
|
}
|
||||||
|
selector.GroupBy(selector.Columns(*_g.flds...)...)
|
||||||
|
if err := selector.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rows := &sql.Rows{}
|
||||||
|
query, args := selector.Query()
|
||||||
|
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return sql.ScanSlice(rows, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyRecordSelect is the builder for selecting fields of IdempotencyRecord entities.
|
||||||
|
type IdempotencyRecordSelect struct {
|
||||||
|
*IdempotencyRecordQuery
|
||||||
|
selector
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate adds the given aggregation functions to the selector query.
|
||||||
|
func (_s *IdempotencyRecordSelect) Aggregate(fns ...AggregateFunc) *IdempotencyRecordSelect {
|
||||||
|
_s.fns = append(_s.fns, fns...)
|
||||||
|
return _s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan applies the selector query and scans the result into the given value.
|
||||||
|
func (_s *IdempotencyRecordSelect) Scan(ctx context.Context, v any) error {
|
||||||
|
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
|
||||||
|
if err := _s.prepareQuery(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return scanWithInterceptors[*IdempotencyRecordQuery, *IdempotencyRecordSelect](ctx, _s.IdempotencyRecordQuery, _s, _s.inters, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_s *IdempotencyRecordSelect) sqlScan(ctx context.Context, root *IdempotencyRecordQuery, v any) error {
|
||||||
|
selector := root.sqlQuery(ctx)
|
||||||
|
aggregation := make([]string, 0, len(_s.fns))
|
||||||
|
for _, fn := range _s.fns {
|
||||||
|
aggregation = append(aggregation, fn(selector))
|
||||||
|
}
|
||||||
|
switch n := len(*_s.selector.flds); {
|
||||||
|
case n == 0 && len(aggregation) > 0:
|
||||||
|
selector.Select(aggregation...)
|
||||||
|
case n != 0 && len(aggregation) > 0:
|
||||||
|
selector.AppendSelect(aggregation...)
|
||||||
|
}
|
||||||
|
rows := &sql.Rows{}
|
||||||
|
query, args := selector.Query()
|
||||||
|
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return sql.ScanSlice(rows, v)
|
||||||
|
}
|
||||||
676
backend/ent/idempotencyrecord_update.go
Normal file
676
backend/ent/idempotencyrecord_update.go
Normal file
@@ -0,0 +1,676 @@
|
|||||||
|
// Code generated by ent, DO NOT EDIT.
|
||||||
|
|
||||||
|
package ent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect/sql"
|
||||||
|
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||||
|
"entgo.io/ent/schema/field"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IdempotencyRecordUpdate is the builder for updating IdempotencyRecord entities.
|
||||||
|
type IdempotencyRecordUpdate struct {
|
||||||
|
config
|
||||||
|
hooks []Hook
|
||||||
|
mutation *IdempotencyRecordMutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where appends a list predicates to the IdempotencyRecordUpdate builder.
|
||||||
|
func (_u *IdempotencyRecordUpdate) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.Where(ps...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetUpdatedAt(v time.Time) *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.SetUpdatedAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetScope sets the "scope" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetScope(v string) *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.SetScope(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableScope sets the "scope" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetNillableScope(v *string) *IdempotencyRecordUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetScope(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIdempotencyKeyHash sets the "idempotency_key_hash" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.SetIdempotencyKeyHash(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableIdempotencyKeyHash sets the "idempotency_key_hash" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetNillableIdempotencyKeyHash(v *string) *IdempotencyRecordUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetIdempotencyKeyHash(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRequestFingerprint sets the "request_fingerprint" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetRequestFingerprint(v string) *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.SetRequestFingerprint(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRequestFingerprint sets the "request_fingerprint" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetNillableRequestFingerprint(v *string) *IdempotencyRecordUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetRequestFingerprint(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStatus sets the "status" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetStatus(v string) *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.SetStatus(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableStatus sets the "status" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetNillableStatus(v *string) *IdempotencyRecordUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetStatus(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetResponseStatus sets the "response_status" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetResponseStatus(v int) *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.ResetResponseStatus()
|
||||||
|
_u.mutation.SetResponseStatus(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableResponseStatus sets the "response_status" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetNillableResponseStatus(v *int) *IdempotencyRecordUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetResponseStatus(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddResponseStatus adds value to the "response_status" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) AddResponseStatus(v int) *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.AddResponseStatus(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearResponseStatus clears the value of the "response_status" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) ClearResponseStatus() *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.ClearResponseStatus()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetResponseBody sets the "response_body" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetResponseBody(v string) *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.SetResponseBody(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableResponseBody sets the "response_body" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetNillableResponseBody(v *string) *IdempotencyRecordUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetResponseBody(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearResponseBody clears the value of the "response_body" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) ClearResponseBody() *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.ClearResponseBody()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetErrorReason sets the "error_reason" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetErrorReason(v string) *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.SetErrorReason(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableErrorReason sets the "error_reason" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetNillableErrorReason(v *string) *IdempotencyRecordUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetErrorReason(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearErrorReason clears the value of the "error_reason" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) ClearErrorReason() *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.ClearErrorReason()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLockedUntil sets the "locked_until" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetLockedUntil(v time.Time) *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.SetLockedUntil(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetLockedUntil(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearLockedUntil clears the value of the "locked_until" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) ClearLockedUntil() *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.ClearLockedUntil()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExpiresAt sets the "expires_at" field.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetExpiresAt(v time.Time) *IdempotencyRecordUpdate {
|
||||||
|
_u.mutation.SetExpiresAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SetNillableExpiresAt(v *time.Time) *IdempotencyRecordUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetExpiresAt(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutation returns the IdempotencyRecordMutation object of the builder.
|
||||||
|
func (_u *IdempotencyRecordUpdate) Mutation() *IdempotencyRecordMutation {
|
||||||
|
return _u.mutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||||
|
func (_u *IdempotencyRecordUpdate) Save(ctx context.Context) (int, error) {
|
||||||
|
_u.defaults()
|
||||||
|
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveX is like Save, but panics if an error occurs.
|
||||||
|
func (_u *IdempotencyRecordUpdate) SaveX(ctx context.Context) int {
|
||||||
|
affected, err := _u.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return affected
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the query.
|
||||||
|
func (_u *IdempotencyRecordUpdate) Exec(ctx context.Context) error {
|
||||||
|
_, err := _u.Save(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_u *IdempotencyRecordUpdate) ExecX(ctx context.Context) {
|
||||||
|
if err := _u.Exec(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaults sets the default values of the builder before save.
|
||||||
|
func (_u *IdempotencyRecordUpdate) defaults() {
|
||||||
|
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||||
|
v := idempotencyrecord.UpdateDefaultUpdatedAt()
|
||||||
|
_u.mutation.SetUpdatedAt(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check runs all checks and user-defined validators on the builder.
|
||||||
|
func (_u *IdempotencyRecordUpdate) check() error {
|
||||||
|
if v, ok := _u.mutation.Scope(); ok {
|
||||||
|
if err := idempotencyrecord.ScopeValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.IdempotencyKeyHash(); ok {
|
||||||
|
if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.RequestFingerprint(); ok {
|
||||||
|
if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.Status(); ok {
|
||||||
|
if err := idempotencyrecord.StatusValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.ErrorReason(); ok {
|
||||||
|
if err := idempotencyrecord.ErrorReasonValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_u *IdempotencyRecordUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||||
|
if err := _u.check(); err != nil {
|
||||||
|
return _node, err
|
||||||
|
}
|
||||||
|
_spec := sqlgraph.NewUpdateSpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64))
|
||||||
|
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||||
|
_spec.Predicate = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Scope(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.IdempotencyKeyHash(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.RequestFingerprint(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Status(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.ResponseStatus(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedResponseStatus(); ok {
|
||||||
|
_spec.AddField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.ResponseStatusCleared() {
|
||||||
|
_spec.ClearField(idempotencyrecord.FieldResponseStatus, field.TypeInt)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.ResponseBody(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.ResponseBodyCleared() {
|
||||||
|
_spec.ClearField(idempotencyrecord.FieldResponseBody, field.TypeString)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.ErrorReason(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.ErrorReasonCleared() {
|
||||||
|
_spec.ClearField(idempotencyrecord.FieldErrorReason, field.TypeString)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.LockedUntil(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.LockedUntilCleared() {
|
||||||
|
_spec.ClearField(idempotencyrecord.FieldLockedUntil, field.TypeTime)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.ExpiresAt(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
||||||
|
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||||
|
err = &NotFoundError{idempotencyrecord.Label}
|
||||||
|
} else if sqlgraph.IsConstraintError(err) {
|
||||||
|
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
_u.mutation.done = true
|
||||||
|
return _node, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdempotencyRecordUpdateOne is the builder for updating a single IdempotencyRecord entity.
|
||||||
|
type IdempotencyRecordUpdateOne struct {
|
||||||
|
config
|
||||||
|
fields []string
|
||||||
|
hooks []Hook
|
||||||
|
mutation *IdempotencyRecordMutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetUpdatedAt(v time.Time) *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.SetUpdatedAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetScope sets the "scope" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetScope(v string) *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.SetScope(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableScope sets the "scope" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetNillableScope(v *string) *IdempotencyRecordUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetScope(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIdempotencyKeyHash sets the "idempotency_key_hash" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.SetIdempotencyKeyHash(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableIdempotencyKeyHash sets the "idempotency_key_hash" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetNillableIdempotencyKeyHash(v *string) *IdempotencyRecordUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetIdempotencyKeyHash(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRequestFingerprint sets the "request_fingerprint" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetRequestFingerprint(v string) *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.SetRequestFingerprint(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableRequestFingerprint sets the "request_fingerprint" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetNillableRequestFingerprint(v *string) *IdempotencyRecordUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetRequestFingerprint(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStatus sets the "status" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetStatus(v string) *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.SetStatus(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableStatus sets the "status" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetNillableStatus(v *string) *IdempotencyRecordUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetStatus(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetResponseStatus sets the "response_status" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetResponseStatus(v int) *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.ResetResponseStatus()
|
||||||
|
_u.mutation.SetResponseStatus(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableResponseStatus sets the "response_status" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetNillableResponseStatus(v *int) *IdempotencyRecordUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetResponseStatus(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddResponseStatus adds value to the "response_status" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) AddResponseStatus(v int) *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.AddResponseStatus(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearResponseStatus clears the value of the "response_status" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) ClearResponseStatus() *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.ClearResponseStatus()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetResponseBody sets the "response_body" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetResponseBody(v string) *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.SetResponseBody(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableResponseBody sets the "response_body" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetNillableResponseBody(v *string) *IdempotencyRecordUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetResponseBody(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearResponseBody clears the value of the "response_body" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) ClearResponseBody() *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.ClearResponseBody()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetErrorReason sets the "error_reason" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetErrorReason(v string) *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.SetErrorReason(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableErrorReason sets the "error_reason" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetNillableErrorReason(v *string) *IdempotencyRecordUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetErrorReason(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearErrorReason clears the value of the "error_reason" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) ClearErrorReason() *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.ClearErrorReason()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLockedUntil sets the "locked_until" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetLockedUntil(v time.Time) *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.SetLockedUntil(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetLockedUntil(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearLockedUntil clears the value of the "locked_until" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) ClearLockedUntil() *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.ClearLockedUntil()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExpiresAt sets the "expires_at" field.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetExpiresAt(v time.Time) *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.SetExpiresAt(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SetNillableExpiresAt(v *time.Time) *IdempotencyRecordUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetExpiresAt(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutation returns the IdempotencyRecordMutation object of the builder.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) Mutation() *IdempotencyRecordMutation {
|
||||||
|
return _u.mutation
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where appends a list predicates to the IdempotencyRecordUpdate builder.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordUpdateOne {
|
||||||
|
_u.mutation.Where(ps...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select allows selecting one or more fields (columns) of the returned entity.
|
||||||
|
// The default is selecting all fields defined in the entity schema.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) Select(field string, fields ...string) *IdempotencyRecordUpdateOne {
|
||||||
|
_u.fields = append([]string{field}, fields...)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save executes the query and returns the updated IdempotencyRecord entity.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) Save(ctx context.Context) (*IdempotencyRecord, error) {
|
||||||
|
_u.defaults()
|
||||||
|
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveX is like Save, but panics if an error occurs.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) SaveX(ctx context.Context) *IdempotencyRecord {
|
||||||
|
node, err := _u.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes the query on the entity.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) Exec(ctx context.Context) error {
|
||||||
|
_, err := _u.Save(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecX is like Exec, but panics if an error occurs.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) ExecX(ctx context.Context) {
|
||||||
|
if err := _u.Exec(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaults sets the default values of the builder before save.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) defaults() {
|
||||||
|
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||||
|
v := idempotencyrecord.UpdateDefaultUpdatedAt()
|
||||||
|
_u.mutation.SetUpdatedAt(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check runs all checks and user-defined validators on the builder.
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) check() error {
|
||||||
|
if v, ok := _u.mutation.Scope(); ok {
|
||||||
|
if err := idempotencyrecord.ScopeValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.IdempotencyKeyHash(); ok {
|
||||||
|
if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.RequestFingerprint(); ok {
|
||||||
|
if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.Status(); ok {
|
||||||
|
if err := idempotencyrecord.StatusValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := _u.mutation.ErrorReason(); ok {
|
||||||
|
if err := idempotencyrecord.ErrorReasonValidator(v); err != nil {
|
||||||
|
return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_u *IdempotencyRecordUpdateOne) sqlSave(ctx context.Context) (_node *IdempotencyRecord, err error) {
|
||||||
|
if err := _u.check(); err != nil {
|
||||||
|
return _node, err
|
||||||
|
}
|
||||||
|
_spec := sqlgraph.NewUpdateSpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64))
|
||||||
|
id, ok := _u.mutation.ID()
|
||||||
|
if !ok {
|
||||||
|
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdempotencyRecord.id" for update`)}
|
||||||
|
}
|
||||||
|
_spec.Node.ID.Value = id
|
||||||
|
if fields := _u.fields; len(fields) > 0 {
|
||||||
|
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||||
|
_spec.Node.Columns = append(_spec.Node.Columns, idempotencyrecord.FieldID)
|
||||||
|
for _, f := range fields {
|
||||||
|
if !idempotencyrecord.ValidColumn(f) {
|
||||||
|
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||||
|
}
|
||||||
|
if f != idempotencyrecord.FieldID {
|
||||||
|
_spec.Node.Columns = append(_spec.Node.Columns, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||||
|
_spec.Predicate = func(selector *sql.Selector) {
|
||||||
|
for i := range ps {
|
||||||
|
ps[i](selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Scope(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.IdempotencyKeyHash(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.RequestFingerprint(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.Status(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.ResponseStatus(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedResponseStatus(); ok {
|
||||||
|
_spec.AddField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.ResponseStatusCleared() {
|
||||||
|
_spec.ClearField(idempotencyrecord.FieldResponseStatus, field.TypeInt)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.ResponseBody(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.ResponseBodyCleared() {
|
||||||
|
_spec.ClearField(idempotencyrecord.FieldResponseBody, field.TypeString)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.ErrorReason(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.ErrorReasonCleared() {
|
||||||
|
_spec.ClearField(idempotencyrecord.FieldErrorReason, field.TypeString)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.LockedUntil(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.LockedUntilCleared() {
|
||||||
|
_spec.ClearField(idempotencyrecord.FieldLockedUntil, field.TypeTime)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.ExpiresAt(); ok {
|
||||||
|
_spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value)
|
||||||
|
}
|
||||||
|
_node = &IdempotencyRecord{config: _u.config}
|
||||||
|
_spec.Assign = _node.assignValues
|
||||||
|
_spec.ScanValues = _node.scanValues
|
||||||
|
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
|
||||||
|
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||||
|
err = &NotFoundError{idempotencyrecord.Label}
|
||||||
|
} else if sqlgraph.IsConstraintError(err) {
|
||||||
|
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_u.mutation.done = true
|
||||||
|
return _node, nil
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||||
@@ -276,6 +277,33 @@ func (f TraverseGroup) Traverse(ctx context.Context, q ent.Query) error {
|
|||||||
return fmt.Errorf("unexpected query type %T. expect *ent.GroupQuery", q)
|
return fmt.Errorf("unexpected query type %T. expect *ent.GroupQuery", q)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The IdempotencyRecordFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||||
|
type IdempotencyRecordFunc func(context.Context, *ent.IdempotencyRecordQuery) (ent.Value, error)
|
||||||
|
|
||||||
|
// Query calls f(ctx, q).
|
||||||
|
func (f IdempotencyRecordFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
|
||||||
|
if q, ok := q.(*ent.IdempotencyRecordQuery); ok {
|
||||||
|
return f(ctx, q)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The TraverseIdempotencyRecord type is an adapter to allow the use of ordinary function as Traverser.
|
||||||
|
type TraverseIdempotencyRecord func(context.Context, *ent.IdempotencyRecordQuery) error
|
||||||
|
|
||||||
|
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
|
||||||
|
func (f TraverseIdempotencyRecord) Intercept(next ent.Querier) ent.Querier {
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traverse calls f(ctx, q).
|
||||||
|
func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) error {
|
||||||
|
if q, ok := q.(*ent.IdempotencyRecordQuery); ok {
|
||||||
|
return f(ctx, q)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q)
|
||||||
|
}
|
||||||
|
|
||||||
// The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier.
|
// The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||||
type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error)
|
type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error)
|
||||||
|
|
||||||
@@ -644,6 +672,8 @@ func NewQuery(q ent.Query) (Query, error) {
|
|||||||
return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil
|
return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil
|
||||||
case *ent.GroupQuery:
|
case *ent.GroupQuery:
|
||||||
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
|
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
|
||||||
|
case *ent.IdempotencyRecordQuery:
|
||||||
|
return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil
|
||||||
case *ent.PromoCodeQuery:
|
case *ent.PromoCodeQuery:
|
||||||
return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil
|
return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil
|
||||||
case *ent.PromoCodeUsageQuery:
|
case *ent.PromoCodeUsageQuery:
|
||||||
|
|||||||
@@ -108,6 +108,8 @@ var (
|
|||||||
{Name: "rate_limited_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
{Name: "rate_limited_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
{Name: "rate_limit_reset_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
{Name: "rate_limit_reset_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
{Name: "overload_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
{Name: "overload_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
|
{Name: "temp_unschedulable_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
|
{Name: "temp_unschedulable_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||||
{Name: "session_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
{Name: "session_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
{Name: "session_window_end", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
{Name: "session_window_end", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
{Name: "session_window_status", Type: field.TypeString, Nullable: true, Size: 20},
|
{Name: "session_window_status", Type: field.TypeString, Nullable: true, Size: 20},
|
||||||
@@ -121,7 +123,7 @@ var (
|
|||||||
ForeignKeys: []*schema.ForeignKey{
|
ForeignKeys: []*schema.ForeignKey{
|
||||||
{
|
{
|
||||||
Symbol: "accounts_proxies_proxy",
|
Symbol: "accounts_proxies_proxy",
|
||||||
Columns: []*schema.Column{AccountsColumns[25]},
|
Columns: []*schema.Column{AccountsColumns[27]},
|
||||||
RefColumns: []*schema.Column{ProxiesColumns[0]},
|
RefColumns: []*schema.Column{ProxiesColumns[0]},
|
||||||
OnDelete: schema.SetNull,
|
OnDelete: schema.SetNull,
|
||||||
},
|
},
|
||||||
@@ -145,7 +147,7 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "account_proxy_id",
|
Name: "account_proxy_id",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[25]},
|
Columns: []*schema.Column{AccountsColumns[27]},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "account_priority",
|
Name: "account_priority",
|
||||||
@@ -177,6 +179,16 @@ var (
|
|||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{AccountsColumns[21]},
|
Columns: []*schema.Column{AccountsColumns[21]},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "account_platform_priority",
|
||||||
|
Unique: false,
|
||||||
|
Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "account_priority_status",
|
||||||
|
Unique: false,
|
||||||
|
Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "account_deleted_at",
|
Name: "account_deleted_at",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
@@ -376,6 +388,7 @@ var (
|
|||||||
{Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
{Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
{Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
|
{Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
|
||||||
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
||||||
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
||||||
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
|
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
|
||||||
@@ -419,7 +432,45 @@ var (
|
|||||||
{
|
{
|
||||||
Name: "group_sort_order",
|
Name: "group_sort_order",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{GroupsColumns[29]},
|
Columns: []*schema.Column{GroupsColumns[30]},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// IdempotencyRecordsColumns holds the columns for the "idempotency_records" table.
|
||||||
|
IdempotencyRecordsColumns = []*schema.Column{
|
||||||
|
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||||
|
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
|
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
|
{Name: "scope", Type: field.TypeString, Size: 128},
|
||||||
|
{Name: "idempotency_key_hash", Type: field.TypeString, Size: 64},
|
||||||
|
{Name: "request_fingerprint", Type: field.TypeString, Size: 64},
|
||||||
|
{Name: "status", Type: field.TypeString, Size: 32},
|
||||||
|
{Name: "response_status", Type: field.TypeInt, Nullable: true},
|
||||||
|
{Name: "response_body", Type: field.TypeString, Nullable: true},
|
||||||
|
{Name: "error_reason", Type: field.TypeString, Nullable: true, Size: 128},
|
||||||
|
{Name: "locked_until", Type: field.TypeTime, Nullable: true},
|
||||||
|
{Name: "expires_at", Type: field.TypeTime},
|
||||||
|
}
|
||||||
|
// IdempotencyRecordsTable holds the schema information for the "idempotency_records" table.
|
||||||
|
IdempotencyRecordsTable = &schema.Table{
|
||||||
|
Name: "idempotency_records",
|
||||||
|
Columns: IdempotencyRecordsColumns,
|
||||||
|
PrimaryKey: []*schema.Column{IdempotencyRecordsColumns[0]},
|
||||||
|
Indexes: []*schema.Index{
|
||||||
|
{
|
||||||
|
Name: "idempotencyrecord_scope_idempotency_key_hash",
|
||||||
|
Unique: true,
|
||||||
|
Columns: []*schema.Column{IdempotencyRecordsColumns[3], IdempotencyRecordsColumns[4]},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "idempotencyrecord_expires_at",
|
||||||
|
Unique: false,
|
||||||
|
Columns: []*schema.Column{IdempotencyRecordsColumns[11]},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "idempotencyrecord_status_locked_until",
|
||||||
|
Unique: false,
|
||||||
|
Columns: []*schema.Column{IdempotencyRecordsColumns[6], IdempotencyRecordsColumns[10]},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -771,6 +822,11 @@ var (
|
|||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
|
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "usagelog_group_id_created_at",
|
||||||
|
Unique: false,
|
||||||
|
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// UsersColumns holds the columns for the "users" table.
|
// UsersColumns holds the columns for the "users" table.
|
||||||
@@ -790,6 +846,8 @@ var (
|
|||||||
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||||
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
|
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
|
||||||
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
|
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
|
||||||
|
{Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
|
||||||
|
{Name: "sora_storage_used_bytes", Type: field.TypeInt64, Default: 0},
|
||||||
}
|
}
|
||||||
// UsersTable holds the schema information for the "users" table.
|
// UsersTable holds the schema information for the "users" table.
|
||||||
UsersTable = &schema.Table{
|
UsersTable = &schema.Table{
|
||||||
@@ -995,6 +1053,11 @@ var (
|
|||||||
Unique: false,
|
Unique: false,
|
||||||
Columns: []*schema.Column{UserSubscriptionsColumns[5]},
|
Columns: []*schema.Column{UserSubscriptionsColumns[5]},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "usersubscription_user_id_status_expires_at",
|
||||||
|
Unique: false,
|
||||||
|
Columns: []*schema.Column{UserSubscriptionsColumns[16], UserSubscriptionsColumns[6], UserSubscriptionsColumns[5]},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "usersubscription_assigned_by",
|
Name: "usersubscription_assigned_by",
|
||||||
Unique: false,
|
Unique: false,
|
||||||
@@ -1021,6 +1084,7 @@ var (
|
|||||||
AnnouncementReadsTable,
|
AnnouncementReadsTable,
|
||||||
ErrorPassthroughRulesTable,
|
ErrorPassthroughRulesTable,
|
||||||
GroupsTable,
|
GroupsTable,
|
||||||
|
IdempotencyRecordsTable,
|
||||||
PromoCodesTable,
|
PromoCodesTable,
|
||||||
PromoCodeUsagesTable,
|
PromoCodeUsagesTable,
|
||||||
ProxiesTable,
|
ProxiesTable,
|
||||||
@@ -1066,6 +1130,9 @@ func init() {
|
|||||||
GroupsTable.Annotation = &entsql.Annotation{
|
GroupsTable.Annotation = &entsql.Annotation{
|
||||||
Table: "groups",
|
Table: "groups",
|
||||||
}
|
}
|
||||||
|
IdempotencyRecordsTable.Annotation = &entsql.Annotation{
|
||||||
|
Table: "idempotency_records",
|
||||||
|
}
|
||||||
PromoCodesTable.Annotation = &entsql.Annotation{
|
PromoCodesTable.Annotation = &entsql.Annotation{
|
||||||
Table: "promo_codes",
|
Table: "promo_codes",
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -27,6 +27,9 @@ type ErrorPassthroughRule func(*sql.Selector)
|
|||||||
// Group is the predicate function for group builders.
|
// Group is the predicate function for group builders.
|
||||||
type Group func(*sql.Selector)
|
type Group func(*sql.Selector)
|
||||||
|
|
||||||
|
// IdempotencyRecord is the predicate function for idempotencyrecord builders.
|
||||||
|
type IdempotencyRecord func(*sql.Selector)
|
||||||
|
|
||||||
// PromoCode is the predicate function for promocode builders.
|
// PromoCode is the predicate function for promocode builders.
|
||||||
type PromoCode func(*sql.Selector)
|
type PromoCode func(*sql.Selector)
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||||
@@ -209,7 +210,7 @@ func init() {
|
|||||||
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
|
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
|
||||||
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
|
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
|
||||||
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
|
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
|
||||||
accountDescSessionWindowStatus := accountFields[21].Descriptor()
|
accountDescSessionWindowStatus := accountFields[23].Descriptor()
|
||||||
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
|
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
|
||||||
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
|
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
|
||||||
accountgroupFields := schema.AccountGroup{}.Fields()
|
accountgroupFields := schema.AccountGroup{}.Fields()
|
||||||
@@ -398,26 +399,65 @@ func init() {
|
|||||||
groupDescDefaultValidityDays := groupFields[10].Descriptor()
|
groupDescDefaultValidityDays := groupFields[10].Descriptor()
|
||||||
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
|
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
|
||||||
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
|
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
|
||||||
|
// groupDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
|
||||||
|
groupDescSoraStorageQuotaBytes := groupFields[18].Descriptor()
|
||||||
|
// group.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
|
||||||
|
group.DefaultSoraStorageQuotaBytes = groupDescSoraStorageQuotaBytes.Default.(int64)
|
||||||
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
|
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
|
||||||
groupDescClaudeCodeOnly := groupFields[18].Descriptor()
|
groupDescClaudeCodeOnly := groupFields[19].Descriptor()
|
||||||
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
||||||
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
||||||
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
|
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
|
||||||
groupDescModelRoutingEnabled := groupFields[22].Descriptor()
|
groupDescModelRoutingEnabled := groupFields[23].Descriptor()
|
||||||
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
|
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
|
||||||
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
|
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
|
||||||
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
|
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
|
||||||
groupDescMcpXMLInject := groupFields[23].Descriptor()
|
groupDescMcpXMLInject := groupFields[24].Descriptor()
|
||||||
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
|
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
|
||||||
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
|
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
|
||||||
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
|
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
|
||||||
groupDescSupportedModelScopes := groupFields[24].Descriptor()
|
groupDescSupportedModelScopes := groupFields[25].Descriptor()
|
||||||
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
|
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
|
||||||
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
|
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
|
||||||
// groupDescSortOrder is the schema descriptor for sort_order field.
|
// groupDescSortOrder is the schema descriptor for sort_order field.
|
||||||
groupDescSortOrder := groupFields[25].Descriptor()
|
groupDescSortOrder := groupFields[26].Descriptor()
|
||||||
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
|
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
|
||||||
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
|
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
|
||||||
|
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||||
|
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
||||||
|
_ = idempotencyrecordMixinFields0
|
||||||
|
idempotencyrecordFields := schema.IdempotencyRecord{}.Fields()
|
||||||
|
_ = idempotencyrecordFields
|
||||||
|
// idempotencyrecordDescCreatedAt is the schema descriptor for created_at field.
|
||||||
|
idempotencyrecordDescCreatedAt := idempotencyrecordMixinFields0[0].Descriptor()
|
||||||
|
// idempotencyrecord.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||||
|
idempotencyrecord.DefaultCreatedAt = idempotencyrecordDescCreatedAt.Default.(func() time.Time)
|
||||||
|
// idempotencyrecordDescUpdatedAt is the schema descriptor for updated_at field.
|
||||||
|
idempotencyrecordDescUpdatedAt := idempotencyrecordMixinFields0[1].Descriptor()
|
||||||
|
// idempotencyrecord.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||||
|
idempotencyrecord.DefaultUpdatedAt = idempotencyrecordDescUpdatedAt.Default.(func() time.Time)
|
||||||
|
// idempotencyrecord.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||||
|
idempotencyrecord.UpdateDefaultUpdatedAt = idempotencyrecordDescUpdatedAt.UpdateDefault.(func() time.Time)
|
||||||
|
// idempotencyrecordDescScope is the schema descriptor for scope field.
|
||||||
|
idempotencyrecordDescScope := idempotencyrecordFields[0].Descriptor()
|
||||||
|
// idempotencyrecord.ScopeValidator is a validator for the "scope" field. It is called by the builders before save.
|
||||||
|
idempotencyrecord.ScopeValidator = idempotencyrecordDescScope.Validators[0].(func(string) error)
|
||||||
|
// idempotencyrecordDescIdempotencyKeyHash is the schema descriptor for idempotency_key_hash field.
|
||||||
|
idempotencyrecordDescIdempotencyKeyHash := idempotencyrecordFields[1].Descriptor()
|
||||||
|
// idempotencyrecord.IdempotencyKeyHashValidator is a validator for the "idempotency_key_hash" field. It is called by the builders before save.
|
||||||
|
idempotencyrecord.IdempotencyKeyHashValidator = idempotencyrecordDescIdempotencyKeyHash.Validators[0].(func(string) error)
|
||||||
|
// idempotencyrecordDescRequestFingerprint is the schema descriptor for request_fingerprint field.
|
||||||
|
idempotencyrecordDescRequestFingerprint := idempotencyrecordFields[2].Descriptor()
|
||||||
|
// idempotencyrecord.RequestFingerprintValidator is a validator for the "request_fingerprint" field. It is called by the builders before save.
|
||||||
|
idempotencyrecord.RequestFingerprintValidator = idempotencyrecordDescRequestFingerprint.Validators[0].(func(string) error)
|
||||||
|
// idempotencyrecordDescStatus is the schema descriptor for status field.
|
||||||
|
idempotencyrecordDescStatus := idempotencyrecordFields[3].Descriptor()
|
||||||
|
// idempotencyrecord.StatusValidator is a validator for the "status" field. It is called by the builders before save.
|
||||||
|
idempotencyrecord.StatusValidator = idempotencyrecordDescStatus.Validators[0].(func(string) error)
|
||||||
|
// idempotencyrecordDescErrorReason is the schema descriptor for error_reason field.
|
||||||
|
idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor()
|
||||||
|
// idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save.
|
||||||
|
idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error)
|
||||||
promocodeFields := schema.PromoCode{}.Fields()
|
promocodeFields := schema.PromoCode{}.Fields()
|
||||||
_ = promocodeFields
|
_ = promocodeFields
|
||||||
// promocodeDescCode is the schema descriptor for code field.
|
// promocodeDescCode is the schema descriptor for code field.
|
||||||
@@ -918,6 +958,14 @@ func init() {
|
|||||||
userDescTotpEnabled := userFields[9].Descriptor()
|
userDescTotpEnabled := userFields[9].Descriptor()
|
||||||
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
|
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
|
||||||
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
|
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
|
||||||
|
// userDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
|
||||||
|
userDescSoraStorageQuotaBytes := userFields[11].Descriptor()
|
||||||
|
// user.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
|
||||||
|
user.DefaultSoraStorageQuotaBytes = userDescSoraStorageQuotaBytes.Default.(int64)
|
||||||
|
// userDescSoraStorageUsedBytes is the schema descriptor for sora_storage_used_bytes field.
|
||||||
|
userDescSoraStorageUsedBytes := userFields[12].Descriptor()
|
||||||
|
// user.DefaultSoraStorageUsedBytes holds the default value on creation for the sora_storage_used_bytes field.
|
||||||
|
user.DefaultSoraStorageUsedBytes = userDescSoraStorageUsedBytes.Default.(int64)
|
||||||
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
||||||
_ = userallowedgroupFields
|
_ = userallowedgroupFields
|
||||||
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
||||||
|
|||||||
@@ -164,6 +164,19 @@ func (Account) Fields() []ent.Field {
|
|||||||
Nillable().
|
Nillable().
|
||||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||||
|
|
||||||
|
// temp_unschedulable_until: 临时不可调度状态解除时间
|
||||||
|
// 当命中临时不可调度规则时设置,在此时间前调度器应跳过该账号
|
||||||
|
field.Time("temp_unschedulable_until").
|
||||||
|
Optional().
|
||||||
|
Nillable().
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||||
|
|
||||||
|
// temp_unschedulable_reason: 临时不可调度原因,便于排障审计
|
||||||
|
field.String("temp_unschedulable_reason").
|
||||||
|
Optional().
|
||||||
|
Nillable().
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "text"}),
|
||||||
|
|
||||||
// session_window_*: 会话窗口相关字段
|
// session_window_*: 会话窗口相关字段
|
||||||
// 用于管理某些需要会话时间窗口的 API(如 Claude Pro)
|
// 用于管理某些需要会话时间窗口的 API(如 Claude Pro)
|
||||||
field.Time("session_window_start").
|
field.Time("session_window_start").
|
||||||
@@ -213,6 +226,9 @@ func (Account) Indexes() []ent.Index {
|
|||||||
index.Fields("rate_limited_at"), // 筛选速率限制账户
|
index.Fields("rate_limited_at"), // 筛选速率限制账户
|
||||||
index.Fields("rate_limit_reset_at"), // 筛选速率限制解除时间
|
index.Fields("rate_limit_reset_at"), // 筛选速率限制解除时间
|
||||||
index.Fields("overload_until"), // 筛选过载账户
|
index.Fields("overload_until"), // 筛选过载账户
|
||||||
index.Fields("deleted_at"), // 软删除查询优化
|
// 调度热路径复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐)
|
||||||
|
index.Fields("platform", "priority"),
|
||||||
|
index.Fields("priority", "status"),
|
||||||
|
index.Fields("deleted_at"), // 软删除查询优化
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -105,6 +105,10 @@ func (Group) Fields() []ent.Field {
|
|||||||
Nillable().
|
Nillable().
|
||||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||||
|
|
||||||
|
// Sora 存储配额
|
||||||
|
field.Int64("sora_storage_quota_bytes").
|
||||||
|
Default(0),
|
||||||
|
|
||||||
// Claude Code 客户端限制 (added by migration 029)
|
// Claude Code 客户端限制 (added by migration 029)
|
||||||
field.Bool("claude_code_only").
|
field.Bool("claude_code_only").
|
||||||
Default(false).
|
Default(false).
|
||||||
|
|||||||
@@ -179,5 +179,7 @@ func (UsageLog) Indexes() []ent.Index {
|
|||||||
// 复合索引用于时间范围查询
|
// 复合索引用于时间范围查询
|
||||||
index.Fields("user_id", "created_at"),
|
index.Fields("user_id", "created_at"),
|
||||||
index.Fields("api_key_id", "created_at"),
|
index.Fields("api_key_id", "created_at"),
|
||||||
|
// 分组维度时间范围查询(线上由 SQL 迁移创建 group_id IS NOT NULL 的部分索引)
|
||||||
|
index.Fields("group_id", "created_at"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,6 +72,12 @@ func (User) Fields() []ent.Field {
|
|||||||
field.Time("totp_enabled_at").
|
field.Time("totp_enabled_at").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable(),
|
Nillable(),
|
||||||
|
|
||||||
|
// Sora 存储配额
|
||||||
|
field.Int64("sora_storage_quota_bytes").
|
||||||
|
Default(0),
|
||||||
|
field.Int64("sora_storage_used_bytes").
|
||||||
|
Default(0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -108,6 +108,8 @@ func (UserSubscription) Indexes() []ent.Index {
|
|||||||
index.Fields("group_id"),
|
index.Fields("group_id"),
|
||||||
index.Fields("status"),
|
index.Fields("status"),
|
||||||
index.Fields("expires_at"),
|
index.Fields("expires_at"),
|
||||||
|
// 活跃订阅查询复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐)
|
||||||
|
index.Fields("user_id", "status", "expires_at"),
|
||||||
index.Fields("assigned_by"),
|
index.Fields("assigned_by"),
|
||||||
// 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重新订阅
|
// 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重新订阅
|
||||||
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
|
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ type Tx struct {
|
|||||||
ErrorPassthroughRule *ErrorPassthroughRuleClient
|
ErrorPassthroughRule *ErrorPassthroughRuleClient
|
||||||
// Group is the client for interacting with the Group builders.
|
// Group is the client for interacting with the Group builders.
|
||||||
Group *GroupClient
|
Group *GroupClient
|
||||||
|
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
|
||||||
|
IdempotencyRecord *IdempotencyRecordClient
|
||||||
// PromoCode is the client for interacting with the PromoCode builders.
|
// PromoCode is the client for interacting with the PromoCode builders.
|
||||||
PromoCode *PromoCodeClient
|
PromoCode *PromoCodeClient
|
||||||
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
|
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
|
||||||
@@ -192,6 +194,7 @@ func (tx *Tx) init() {
|
|||||||
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
|
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
|
||||||
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
|
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
|
||||||
tx.Group = NewGroupClient(tx.config)
|
tx.Group = NewGroupClient(tx.config)
|
||||||
|
tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
|
||||||
tx.PromoCode = NewPromoCodeClient(tx.config)
|
tx.PromoCode = NewPromoCodeClient(tx.config)
|
||||||
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
|
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
|
||||||
tx.Proxy = NewProxyClient(tx.config)
|
tx.Proxy = NewProxyClient(tx.config)
|
||||||
|
|||||||
@@ -45,6 +45,10 @@ type User struct {
|
|||||||
TotpEnabled bool `json:"totp_enabled,omitempty"`
|
TotpEnabled bool `json:"totp_enabled,omitempty"`
|
||||||
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
|
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
|
||||||
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
|
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
|
||||||
|
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
||||||
|
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
||||||
|
// SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field.
|
||||||
|
SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"`
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the UserQuery when eager-loading is set.
|
// The values are being populated by the UserQuery when eager-loading is set.
|
||||||
Edges UserEdges `json:"edges"`
|
Edges UserEdges `json:"edges"`
|
||||||
@@ -177,7 +181,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
|
|||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case user.FieldBalance:
|
case user.FieldBalance:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case user.FieldID, user.FieldConcurrency:
|
case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
|
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
@@ -291,6 +295,18 @@ func (_m *User) assignValues(columns []string, values []any) error {
|
|||||||
_m.TotpEnabledAt = new(time.Time)
|
_m.TotpEnabledAt = new(time.Time)
|
||||||
*_m.TotpEnabledAt = value.Time
|
*_m.TotpEnabledAt = value.Time
|
||||||
}
|
}
|
||||||
|
case user.FieldSoraStorageQuotaBytes:
|
||||||
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.SoraStorageQuotaBytes = value.Int64
|
||||||
|
}
|
||||||
|
case user.FieldSoraStorageUsedBytes:
|
||||||
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.SoraStorageUsedBytes = value.Int64
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@@ -424,6 +440,12 @@ func (_m *User) String() string {
|
|||||||
builder.WriteString("totp_enabled_at=")
|
builder.WriteString("totp_enabled_at=")
|
||||||
builder.WriteString(v.Format(time.ANSIC))
|
builder.WriteString(v.Format(time.ANSIC))
|
||||||
}
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("sora_storage_quota_bytes=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("sora_storage_used_bytes=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes))
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,6 +43,10 @@ const (
|
|||||||
FieldTotpEnabled = "totp_enabled"
|
FieldTotpEnabled = "totp_enabled"
|
||||||
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
|
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
|
||||||
FieldTotpEnabledAt = "totp_enabled_at"
|
FieldTotpEnabledAt = "totp_enabled_at"
|
||||||
|
// FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
|
||||||
|
FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
|
||||||
|
// FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database.
|
||||||
|
FieldSoraStorageUsedBytes = "sora_storage_used_bytes"
|
||||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||||
EdgeAPIKeys = "api_keys"
|
EdgeAPIKeys = "api_keys"
|
||||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||||
@@ -152,6 +156,8 @@ var Columns = []string{
|
|||||||
FieldTotpSecretEncrypted,
|
FieldTotpSecretEncrypted,
|
||||||
FieldTotpEnabled,
|
FieldTotpEnabled,
|
||||||
FieldTotpEnabledAt,
|
FieldTotpEnabledAt,
|
||||||
|
FieldSoraStorageQuotaBytes,
|
||||||
|
FieldSoraStorageUsedBytes,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -208,6 +214,10 @@ var (
|
|||||||
DefaultNotes string
|
DefaultNotes string
|
||||||
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
|
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
|
||||||
DefaultTotpEnabled bool
|
DefaultTotpEnabled bool
|
||||||
|
// DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
|
||||||
|
DefaultSoraStorageQuotaBytes int64
|
||||||
|
// DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field.
|
||||||
|
DefaultSoraStorageUsedBytes int64
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the User queries.
|
// OrderOption defines the ordering options for the User queries.
|
||||||
@@ -288,6 +298,16 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
|
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
|
||||||
|
func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field.
|
||||||
|
func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByAPIKeysCount orders the results by api_keys count.
|
// ByAPIKeysCount orders the results by api_keys count.
|
||||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -125,6 +125,16 @@ func TotpEnabledAt(v time.Time) predicate.User {
|
|||||||
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
|
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
|
||||||
|
func SoraStorageQuotaBytes(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ.
|
||||||
|
func SoraStorageUsedBytes(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.User {
|
func CreatedAtEQ(v time.Time) predicate.User {
|
||||||
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@@ -860,6 +870,86 @@ func TotpEnabledAtNotNil() predicate.User {
|
|||||||
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
|
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesEQ(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesNEQ(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesIn(vs ...int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesGT(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesGTE(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesLT(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
|
||||||
|
func SoraStorageQuotaBytesLTE(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field.
|
||||||
|
func SoraStorageUsedBytesEQ(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field.
|
||||||
|
func SoraStorageUsedBytesNEQ(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field.
|
||||||
|
func SoraStorageUsedBytesIn(vs ...int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field.
|
||||||
|
func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field.
|
||||||
|
func SoraStorageUsedBytesGT(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field.
|
||||||
|
func SoraStorageUsedBytesGTE(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field.
|
||||||
|
func SoraStorageUsedBytesLT(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field.
|
||||||
|
func SoraStorageUsedBytesLTE(v int64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v))
|
||||||
|
}
|
||||||
|
|
||||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||||
func HasAPIKeys() predicate.User {
|
func HasAPIKeys() predicate.User {
|
||||||
return predicate.User(func(s *sql.Selector) {
|
return predicate.User(func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -210,6 +210,34 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||||
|
func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate {
|
||||||
|
_c.mutation.SetSoraStorageQuotaBytes(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
||||||
|
func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetSoraStorageQuotaBytes(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
||||||
|
func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate {
|
||||||
|
_c.mutation.SetSoraStorageUsedBytes(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
|
||||||
|
func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetSoraStorageUsedBytes(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
||||||
_c.mutation.AddAPIKeyIDs(ids...)
|
_c.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -424,6 +452,14 @@ func (_c *UserCreate) defaults() error {
|
|||||||
v := user.DefaultTotpEnabled
|
v := user.DefaultTotpEnabled
|
||||||
_c.mutation.SetTotpEnabled(v)
|
_c.mutation.SetTotpEnabled(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
|
||||||
|
v := user.DefaultSoraStorageQuotaBytes
|
||||||
|
_c.mutation.SetSoraStorageQuotaBytes(v)
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
|
||||||
|
v := user.DefaultSoraStorageUsedBytes
|
||||||
|
_c.mutation.SetSoraStorageUsedBytes(v)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -487,6 +523,12 @@ func (_c *UserCreate) check() error {
|
|||||||
if _, ok := _c.mutation.TotpEnabled(); !ok {
|
if _, ok := _c.mutation.TotpEnabled(); !ok {
|
||||||
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
|
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
|
||||||
|
return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)}
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
|
||||||
|
return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -570,6 +612,14 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
||||||
_node.TotpEnabledAt = &value
|
_node.TotpEnabledAt = &value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
|
||||||
|
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||||
|
_node.SoraStorageQuotaBytes = value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.SoraStorageUsedBytes(); ok {
|
||||||
|
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
||||||
|
_node.SoraStorageUsedBytes = value
|
||||||
|
}
|
||||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -956,6 +1006,42 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||||
|
func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert {
|
||||||
|
u.Set(user.FieldSoraStorageQuotaBytes, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert {
|
||||||
|
u.SetExcluded(user.FieldSoraStorageQuotaBytes)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
||||||
|
func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert {
|
||||||
|
u.Add(user.FieldSoraStorageQuotaBytes, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
||||||
|
func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert {
|
||||||
|
u.Set(user.FieldSoraStorageUsedBytes, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert {
|
||||||
|
u.SetExcluded(user.FieldSoraStorageUsedBytes)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
|
||||||
|
func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert {
|
||||||
|
u.Add(user.FieldSoraStorageUsedBytes, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@@ -1218,6 +1304,48 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||||
|
func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetSoraStorageQuotaBytes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
||||||
|
func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.AddSoraStorageQuotaBytes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateSoraStorageQuotaBytes()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
||||||
|
func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetSoraStorageUsedBytes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
|
||||||
|
func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.AddSoraStorageUsedBytes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateSoraStorageUsedBytes()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@@ -1646,6 +1774,48 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||||
|
func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetSoraStorageQuotaBytes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
|
||||||
|
func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.AddSoraStorageQuotaBytes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateSoraStorageQuotaBytes()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
||||||
|
func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetSoraStorageUsedBytes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
|
||||||
|
func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.AddSoraStorageUsedBytes(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateSoraStorageUsedBytes()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@@ -242,6 +242,48 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||||
|
func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate {
|
||||||
|
_u.mutation.ResetSoraStorageQuotaBytes()
|
||||||
|
_u.mutation.SetSoraStorageQuotaBytes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetSoraStorageQuotaBytes(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
|
||||||
|
func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate {
|
||||||
|
_u.mutation.AddSoraStorageQuotaBytes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
||||||
|
func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate {
|
||||||
|
_u.mutation.ResetSoraStorageUsedBytes()
|
||||||
|
_u.mutation.SetSoraStorageUsedBytes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetSoraStorageUsedBytes(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
|
||||||
|
func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate {
|
||||||
|
_u.mutation.AddSoraStorageUsedBytes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -709,6 +751,18 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.TotpEnabledAtCleared() {
|
if _u.mutation.TotpEnabledAtCleared() {
|
||||||
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
|
||||||
|
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
|
||||||
|
_spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
|
||||||
|
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
|
||||||
|
_spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1352,6 +1406,48 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
|
||||||
|
func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne {
|
||||||
|
_u.mutation.ResetSoraStorageQuotaBytes()
|
||||||
|
_u.mutation.SetSoraStorageQuotaBytes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetSoraStorageQuotaBytes(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
|
||||||
|
func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne {
|
||||||
|
_u.mutation.AddSoraStorageQuotaBytes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
|
||||||
|
func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne {
|
||||||
|
_u.mutation.ResetSoraStorageUsedBytes()
|
||||||
|
_u.mutation.SetSoraStorageUsedBytes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetSoraStorageUsedBytes(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
|
||||||
|
func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne {
|
||||||
|
_u.mutation.AddSoraStorageUsedBytes(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -1849,6 +1945,18 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
|||||||
if _u.mutation.TotpEnabledAtCleared() {
|
if _u.mutation.TotpEnabledAtCleared() {
|
||||||
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
|
||||||
|
_spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
|
||||||
|
_spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
|
||||||
|
_spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
|
||||||
|
_spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
|
|||||||
@@ -7,7 +7,11 @@ require (
|
|||||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||||
github.com/alitto/pond/v2 v2.6.2
|
github.com/alitto/pond/v2 v2.6.2
|
||||||
|
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||||
github.com/cespare/xxhash/v2 v2.3.0
|
github.com/cespare/xxhash/v2 v2.3.0
|
||||||
|
github.com/coder/websocket v1.8.14
|
||||||
github.com/dgraph-io/ristretto v0.2.0
|
github.com/dgraph-io/ristretto v0.2.0
|
||||||
github.com/gin-gonic/gin v1.9.1
|
github.com/gin-gonic/gin v1.9.1
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||||
@@ -34,6 +38,8 @@ require (
|
|||||||
golang.org/x/net v0.49.0
|
golang.org/x/net v0.49.0
|
||||||
golang.org/x/sync v0.19.0
|
golang.org/x/sync v0.19.0
|
||||||
golang.org/x/term v0.40.0
|
golang.org/x/term v0.40.0
|
||||||
|
google.golang.org/grpc v1.75.1
|
||||||
|
google.golang.org/protobuf v1.36.10
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
modernc.org/sqlite v1.44.3
|
modernc.org/sqlite v1.44.3
|
||||||
@@ -47,6 +53,22 @@ require (
|
|||||||
github.com/agext/levenshtein v1.2.3 // indirect
|
github.com/agext/levenshtein v1.2.3 // indirect
|
||||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
|
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
||||||
|
github.com/aws/smithy-go v1.24.1 // indirect
|
||||||
github.com/bdandy/go-errors v1.2.2 // indirect
|
github.com/bdandy/go-errors v1.2.2 // indirect
|
||||||
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
||||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||||
@@ -146,7 +168,6 @@ require (
|
|||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||||
go.opentelemetry.io/otel v1.37.0 // indirect
|
go.opentelemetry.io/otel v1.37.0 // indirect
|
||||||
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
||||||
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
|
|
||||||
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
||||||
go.uber.org/atomic v1.10.0 // indirect
|
go.uber.org/atomic v1.10.0 // indirect
|
||||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||||
@@ -156,8 +177,7 @@ require (
|
|||||||
golang.org/x/mod v0.32.0 // indirect
|
golang.org/x/mod v0.32.0 // indirect
|
||||||
golang.org/x/sys v0.41.0 // indirect
|
golang.org/x/sys v0.41.0 // indirect
|
||||||
golang.org/x/text v0.34.0 // indirect
|
golang.org/x/text v0.34.0 // indirect
|
||||||
google.golang.org/grpc v1.75.1 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
|
||||||
google.golang.org/protobuf v1.36.10 // indirect
|
|
||||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||||
modernc.org/libc v1.67.6 // indirect
|
modernc.org/libc v1.67.6 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
|
|||||||
@@ -22,6 +22,44 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
|
|||||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||||
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
|
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
|
||||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
||||||
|
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
||||||
|
github.com/aws/aws-sdk-go-v2/config v1.32.10/go.mod h1:2rUIOnA2JaiqYmSKYmRJlcMWy6qTj1vuRFscppSBMcw=
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.19.10 h1:EEhmEUFCE1Yhl7vDhNOI5OCL/iKMdkkYFTRpZXNw7m8=
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.19.10/go.mod h1:RnnlFCAlxQCkN2Q379B67USkBMu1PipEEiibzYN5UTE=
|
||||||
|
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 h1:Ii4s+Sq3yDfaMLpjrJsqD6SmG/Wq/P5L/hw2qa78UAY=
|
||||||
|
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18/go.mod h1:6x81qnY++ovptLE6nWQeWrpXxbnlIex+4H4eYYGcqfc=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 h1:eZioDaZGJ0tMM4gzmkNIO2aAoQd+je7Ug7TkvAzlmkU=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18/go.mod h1:CCXwUKAJdoWr6/NcxZ+zsiPr6oH/Q5aTooRGYieAyj4=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 h1:CeY9LUdur+Dxoeldqoun6y4WtJ3RQtzk0JMP2gfUay0=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5/go.mod h1:AZLZf2fMaahW5s/wMRciu1sYbdsikT/UHwbUjOdEVTc=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 h1:fJvQ5mIBVfKtiyx0AHY6HeWcRX5LGANLpq8SVR+Uazs=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10/go.mod h1:Kzm5e6OmNH8VMkgK9t+ry5jEih4Y8whqs+1hrkxim1I=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 h1:LTRCYFlnnKFlKsyIQxKhJuDuA3ZkrDQMRYm6rXiHlLY=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18/go.mod h1:XhwkgGG6bHSd00nO/mexWTcTjgd6PjuvWQMqSn2UaEk=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 h1:/A/xDuZAVD2BpsS2fftFRo/NoEKQJ8YTnJDEHBy2Gtg=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18/go.mod h1:hWe9b4f+djUQGmyiGEeOnZv69dtMSgpDRIvNMvuvzvY=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 h1:M1A9AjcFwlxTLuf0Faj88L8Iqw0n/AJHjpZTQzMMsSc=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2/go.mod h1:KsdTV6Q9WKUZm2mNJnUFmIoXfZux91M3sr/a4REX8e0=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 h1:MzORe+J94I+hYu2a6XmV5yC9huoTv8NRcCrUNedDypQ=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/signin v1.0.6/go.mod h1:hXzcHLARD7GeWnifd8j9RWqtfIgxj4/cAtIVIK7hg8g=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 h1:7oGD8KPfBOJGXiCoRKrrrQkbvCp8N++u36hrLMPey6o=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11/go.mod h1:0DO9B5EUJQlIDif+XJRWCljZRKsAFKh3gpFz7UnDtOo=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWAXLGFIizeqkdkKgRlJwWc=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
||||||
|
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
||||||
|
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
||||||
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
||||||
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
||||||
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
||||||
@@ -56,6 +94,12 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
|
|||||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||||
|
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||||
|
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||||
|
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||||
|
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||||
|
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||||
|
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||||
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||||
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||||
@@ -127,6 +171,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
|||||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
|
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
@@ -190,6 +236,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
|||||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||||
|
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||||
@@ -223,6 +271,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
|||||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
|
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||||
|
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
@@ -274,6 +324,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
|||||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||||
|
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||||
|
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||||
@@ -344,6 +396,8 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd
|
|||||||
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||||
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
||||||
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
||||||
|
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
|
||||||
|
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
|
||||||
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||||
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||||
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
|
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
|
||||||
@@ -399,6 +453,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
|
|||||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||||
|
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||||
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
|
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
|
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
|
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
|
||||||
|
|||||||
@@ -364,6 +364,8 @@ type GatewayConfig struct {
|
|||||||
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
|
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
|
||||||
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
|
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
|
||||||
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
|
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
|
||||||
|
// OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
|
||||||
|
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
|
||||||
|
|
||||||
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
|
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
|
||||||
// MaxIdleConns: 所有主机的最大空闲连接总数
|
// MaxIdleConns: 所有主机的最大空闲连接总数
|
||||||
@@ -450,6 +452,101 @@ type GatewayConfig struct {
|
|||||||
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
|
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
|
||||||
|
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
|
||||||
|
type GatewayOpenAIWSConfig struct {
|
||||||
|
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
|
||||||
|
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
|
||||||
|
// IngressModeDefault: ingress 默认模式(off/shared/dedicated)
|
||||||
|
IngressModeDefault string `mapstructure:"ingress_mode_default"`
|
||||||
|
// Enabled: 全局总开关(默认 true)
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
// OAuthEnabled: 是否允许 OpenAI OAuth 账号使用 WS
|
||||||
|
OAuthEnabled bool `mapstructure:"oauth_enabled"`
|
||||||
|
// APIKeyEnabled: 是否允许 OpenAI API Key 账号使用 WS
|
||||||
|
APIKeyEnabled bool `mapstructure:"apikey_enabled"`
|
||||||
|
// ForceHTTP: 全局强制 HTTP(用于紧急回滚)
|
||||||
|
ForceHTTP bool `mapstructure:"force_http"`
|
||||||
|
// AllowStoreRecovery: 允许在 WSv2 下按策略恢复 store=true(默认 false)
|
||||||
|
AllowStoreRecovery bool `mapstructure:"allow_store_recovery"`
|
||||||
|
// IngressPreviousResponseRecoveryEnabled: ingress 模式收到 previous_response_not_found 时,是否允许自动去掉 previous_response_id 重试一次(默认 true)
|
||||||
|
IngressPreviousResponseRecoveryEnabled bool `mapstructure:"ingress_previous_response_recovery_enabled"`
|
||||||
|
// StoreDisabledConnMode: store=false 且无可复用会话连接时的建连策略(strict/adaptive/off)
|
||||||
|
// - strict: 强制新建连接(隔离优先)
|
||||||
|
// - adaptive: 仅在高风险失败后强制新建连接(性能与隔离折中)
|
||||||
|
// - off: 不强制新建连接(复用优先)
|
||||||
|
StoreDisabledConnMode string `mapstructure:"store_disabled_conn_mode"`
|
||||||
|
// StoreDisabledForceNewConn: store=false 且无可复用粘连连接时是否强制新建连接(默认 true,保障会话隔离)
|
||||||
|
// 兼容旧配置;当 StoreDisabledConnMode 为空时才生效。
|
||||||
|
StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"`
|
||||||
|
// PrewarmGenerateEnabled: 是否启用 WSv2 generate=false 预热(默认 false)
|
||||||
|
PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"`
|
||||||
|
|
||||||
|
// Feature 开关:v2 优先于 v1
|
||||||
|
ResponsesWebsockets bool `mapstructure:"responses_websockets"`
|
||||||
|
ResponsesWebsocketsV2 bool `mapstructure:"responses_websockets_v2"`
|
||||||
|
|
||||||
|
// 连接池参数
|
||||||
|
MaxConnsPerAccount int `mapstructure:"max_conns_per_account"`
|
||||||
|
MinIdlePerAccount int `mapstructure:"min_idle_per_account"`
|
||||||
|
MaxIdlePerAccount int `mapstructure:"max_idle_per_account"`
|
||||||
|
// DynamicMaxConnsByAccountConcurrencyEnabled: 是否按账号并发动态计算连接池上限
|
||||||
|
DynamicMaxConnsByAccountConcurrencyEnabled bool `mapstructure:"dynamic_max_conns_by_account_concurrency_enabled"`
|
||||||
|
// OAuthMaxConnsFactor: OAuth 账号连接池系数(effective=ceil(concurrency*factor))
|
||||||
|
OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"`
|
||||||
|
// APIKeyMaxConnsFactor: API Key 账号连接池系数(effective=ceil(concurrency*factor))
|
||||||
|
APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"`
|
||||||
|
DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
|
||||||
|
ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
|
||||||
|
WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
|
||||||
|
PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"`
|
||||||
|
QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"`
|
||||||
|
// EventFlushBatchSize: WS 流式写出批量 flush 阈值(事件条数)
|
||||||
|
EventFlushBatchSize int `mapstructure:"event_flush_batch_size"`
|
||||||
|
// EventFlushIntervalMS: WS 流式写出最大等待时间(毫秒);0 表示仅按 batch 触发
|
||||||
|
EventFlushIntervalMS int `mapstructure:"event_flush_interval_ms"`
|
||||||
|
// PrewarmCooldownMS: 连接池预热触发冷却时间(毫秒)
|
||||||
|
PrewarmCooldownMS int `mapstructure:"prewarm_cooldown_ms"`
|
||||||
|
// FallbackCooldownSeconds: WS 回退冷却窗口,避免 WS/HTTP 抖动;0 表示关闭冷却
|
||||||
|
FallbackCooldownSeconds int `mapstructure:"fallback_cooldown_seconds"`
|
||||||
|
// RetryBackoffInitialMS: WS 重试初始退避(毫秒);<=0 表示关闭退避
|
||||||
|
RetryBackoffInitialMS int `mapstructure:"retry_backoff_initial_ms"`
|
||||||
|
// RetryBackoffMaxMS: WS 重试最大退避(毫秒)
|
||||||
|
RetryBackoffMaxMS int `mapstructure:"retry_backoff_max_ms"`
|
||||||
|
// RetryJitterRatio: WS 重试退避抖动比例(0-1)
|
||||||
|
RetryJitterRatio float64 `mapstructure:"retry_jitter_ratio"`
|
||||||
|
// RetryTotalBudgetMS: WS 单次请求重试总预算(毫秒);0 表示关闭预算限制
|
||||||
|
RetryTotalBudgetMS int `mapstructure:"retry_total_budget_ms"`
|
||||||
|
// PayloadLogSampleRate: payload_schema 日志采样率(0-1)
|
||||||
|
PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"`
|
||||||
|
|
||||||
|
// 账号调度与粘连参数
|
||||||
|
LBTopK int `mapstructure:"lb_top_k"`
|
||||||
|
// StickySessionTTLSeconds: session_hash -> account_id 粘连 TTL
|
||||||
|
StickySessionTTLSeconds int `mapstructure:"sticky_session_ttl_seconds"`
|
||||||
|
// SessionHashReadOldFallback: 会话哈希迁移期是否允许“新 key 未命中时回退读旧 SHA-256 key”
|
||||||
|
SessionHashReadOldFallback bool `mapstructure:"session_hash_read_old_fallback"`
|
||||||
|
// SessionHashDualWriteOld: 会话哈希迁移期是否双写旧 SHA-256 key(短 TTL)
|
||||||
|
SessionHashDualWriteOld bool `mapstructure:"session_hash_dual_write_old"`
|
||||||
|
// MetadataBridgeEnabled: RequestMetadata 迁移期是否保留旧 ctxkey.* 兼容桥接
|
||||||
|
MetadataBridgeEnabled bool `mapstructure:"metadata_bridge_enabled"`
|
||||||
|
// StickyResponseIDTTLSeconds: response_id -> account_id 粘连 TTL
|
||||||
|
StickyResponseIDTTLSeconds int `mapstructure:"sticky_response_id_ttl_seconds"`
|
||||||
|
// StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退)
|
||||||
|
StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"`
|
||||||
|
|
||||||
|
SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重。
|
||||||
|
type GatewayOpenAIWSSchedulerScoreWeights struct {
|
||||||
|
Priority float64 `mapstructure:"priority"`
|
||||||
|
Load float64 `mapstructure:"load"`
|
||||||
|
Queue float64 `mapstructure:"queue"`
|
||||||
|
ErrorRate float64 `mapstructure:"error_rate"`
|
||||||
|
TTFT float64 `mapstructure:"ttft"`
|
||||||
|
}
|
||||||
|
|
||||||
// GatewayUsageRecordConfig 使用量记录异步队列配置
|
// GatewayUsageRecordConfig 使用量记录异步队列配置
|
||||||
type GatewayUsageRecordConfig struct {
|
type GatewayUsageRecordConfig struct {
|
||||||
// WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限)
|
// WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限)
|
||||||
@@ -886,6 +983,12 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
|
|||||||
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
|
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
|
||||||
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
|
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
|
||||||
|
|
||||||
|
// 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。
|
||||||
|
// 新键未配置(<=0)时回退旧键;新键优先。
|
||||||
|
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
|
||||||
|
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
|
||||||
|
}
|
||||||
|
|
||||||
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
|
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
|
||||||
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
|
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
|
||||||
if cfg.Totp.EncryptionKey == "" {
|
if cfg.Totp.EncryptionKey == "" {
|
||||||
@@ -945,7 +1048,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||||
viper.SetDefault("server.trusted_proxies", []string{})
|
viper.SetDefault("server.trusted_proxies", []string{})
|
||||||
viper.SetDefault("server.max_request_body_size", int64(100*1024*1024))
|
viper.SetDefault("server.max_request_body_size", int64(256*1024*1024))
|
||||||
// H2C 默认配置
|
// H2C 默认配置
|
||||||
viper.SetDefault("server.h2c.enabled", false)
|
viper.SetDefault("server.h2c.enabled", false)
|
||||||
viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流
|
viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流
|
||||||
@@ -1088,9 +1191,9 @@ func setDefaults() {
|
|||||||
// RateLimit
|
// RateLimit
|
||||||
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
|
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
|
||||||
|
|
||||||
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据的配置
|
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移)
|
||||||
viper.SetDefault("pricing.remote_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json")
|
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json")
|
||||||
viper.SetDefault("pricing.hash_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256")
|
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256")
|
||||||
viper.SetDefault("pricing.data_dir", "./data")
|
viper.SetDefault("pricing.data_dir", "./data")
|
||||||
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
|
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
|
||||||
viper.SetDefault("pricing.update_interval_hours", 24)
|
viper.SetDefault("pricing.update_interval_hours", 24)
|
||||||
@@ -1157,9 +1260,55 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.max_account_switches_gemini", 3)
|
viper.SetDefault("gateway.max_account_switches_gemini", 3)
|
||||||
viper.SetDefault("gateway.force_codex_cli", false)
|
viper.SetDefault("gateway.force_codex_cli", false)
|
||||||
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
||||||
|
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
|
||||||
|
viper.SetDefault("gateway.openai_ws.enabled", true)
|
||||||
|
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
|
||||||
|
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
|
||||||
|
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
|
||||||
|
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
|
||||||
|
viper.SetDefault("gateway.openai_ws.force_http", false)
|
||||||
|
viper.SetDefault("gateway.openai_ws.allow_store_recovery", false)
|
||||||
|
viper.SetDefault("gateway.openai_ws.ingress_previous_response_recovery_enabled", true)
|
||||||
|
viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict")
|
||||||
|
viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true)
|
||||||
|
viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false)
|
||||||
|
viper.SetDefault("gateway.openai_ws.responses_websockets", false)
|
||||||
|
viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true)
|
||||||
|
viper.SetDefault("gateway.openai_ws.max_conns_per_account", 128)
|
||||||
|
viper.SetDefault("gateway.openai_ws.min_idle_per_account", 4)
|
||||||
|
viper.SetDefault("gateway.openai_ws.max_idle_per_account", 12)
|
||||||
|
viper.SetDefault("gateway.openai_ws.dynamic_max_conns_by_account_concurrency_enabled", true)
|
||||||
|
viper.SetDefault("gateway.openai_ws.oauth_max_conns_factor", 1.0)
|
||||||
|
viper.SetDefault("gateway.openai_ws.apikey_max_conns_factor", 1.0)
|
||||||
|
viper.SetDefault("gateway.openai_ws.dial_timeout_seconds", 10)
|
||||||
|
viper.SetDefault("gateway.openai_ws.read_timeout_seconds", 900)
|
||||||
|
viper.SetDefault("gateway.openai_ws.write_timeout_seconds", 120)
|
||||||
|
viper.SetDefault("gateway.openai_ws.pool_target_utilization", 0.7)
|
||||||
|
viper.SetDefault("gateway.openai_ws.queue_limit_per_conn", 64)
|
||||||
|
viper.SetDefault("gateway.openai_ws.event_flush_batch_size", 1)
|
||||||
|
viper.SetDefault("gateway.openai_ws.event_flush_interval_ms", 10)
|
||||||
|
viper.SetDefault("gateway.openai_ws.prewarm_cooldown_ms", 300)
|
||||||
|
viper.SetDefault("gateway.openai_ws.fallback_cooldown_seconds", 30)
|
||||||
|
viper.SetDefault("gateway.openai_ws.retry_backoff_initial_ms", 120)
|
||||||
|
viper.SetDefault("gateway.openai_ws.retry_backoff_max_ms", 2000)
|
||||||
|
viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2)
|
||||||
|
viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000)
|
||||||
|
viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2)
|
||||||
|
viper.SetDefault("gateway.openai_ws.lb_top_k", 7)
|
||||||
|
viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600)
|
||||||
|
viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true)
|
||||||
|
viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true)
|
||||||
|
viper.SetDefault("gateway.openai_ws.metadata_bridge_enabled", true)
|
||||||
|
viper.SetDefault("gateway.openai_ws.sticky_response_id_ttl_seconds", 3600)
|
||||||
|
viper.SetDefault("gateway.openai_ws.sticky_previous_response_ttl_seconds", 3600)
|
||||||
|
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.priority", 1.0)
|
||||||
|
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.load", 1.0)
|
||||||
|
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
|
||||||
|
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
|
||||||
|
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
|
||||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||||
viper.SetDefault("gateway.antigravity_extra_retries", 10)
|
viper.SetDefault("gateway.antigravity_extra_retries", 10)
|
||||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
|
||||||
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
||||||
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
||||||
viper.SetDefault("gateway.gemini_debug_response_headers", false)
|
viper.SetDefault("gateway.gemini_debug_response_headers", false)
|
||||||
@@ -1747,6 +1896,118 @@ func (c *Config) Validate() error {
|
|||||||
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
|
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
|
||||||
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
|
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
|
||||||
}
|
}
|
||||||
|
// 兼容旧键 sticky_previous_response_ttl_seconds
|
||||||
|
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
|
||||||
|
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.MaxConnsPerAccount <= 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.max_conns_per_account must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.MinIdlePerAccount < 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.MaxIdlePerAccount < 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.MinIdlePerAccount > c.Gateway.OpenAIWS.MaxIdlePerAccount {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.MaxIdlePerAccount > c.Gateway.OpenAIWS.MaxConnsPerAccount {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.OAuthMaxConnsFactor <= 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.oauth_max_conns_factor must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.APIKeyMaxConnsFactor <= 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.apikey_max_conns_factor must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.DialTimeoutSeconds <= 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.dial_timeout_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.ReadTimeoutSeconds <= 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.read_timeout_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.WriteTimeoutSeconds <= 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.write_timeout_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.PoolTargetUtilization <= 0 || c.Gateway.OpenAIWS.PoolTargetUtilization > 1 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.pool_target_utilization must be within (0,1]")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.QueueLimitPerConn <= 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.queue_limit_per_conn must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.EventFlushBatchSize <= 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.event_flush_batch_size must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.EventFlushIntervalMS < 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.event_flush_interval_ms must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.PrewarmCooldownMS < 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.prewarm_cooldown_ms must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.FallbackCooldownSeconds < 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.fallback_cooldown_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.RetryBackoffInitialMS < 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.retry_backoff_initial_ms must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.RetryBackoffMaxMS < 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.RetryBackoffInitialMS > 0 && c.Gateway.OpenAIWS.RetryBackoffMaxMS > 0 &&
|
||||||
|
c.Gateway.OpenAIWS.RetryBackoffMaxMS < c.Gateway.OpenAIWS.RetryBackoffInitialMS {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be >= retry_backoff_initial_ms")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.RetryJitterRatio < 0 || c.Gateway.OpenAIWS.RetryJitterRatio > 1 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.retry_jitter_ratio must be within [0,1]")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative")
|
||||||
|
}
|
||||||
|
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
|
||||||
|
switch mode {
|
||||||
|
case "off", "shared", "dedicated":
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
|
||||||
|
switch mode {
|
||||||
|
case "strict", "adaptive", "off":
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("gateway.openai_ws.store_disabled_conn_mode must be one of strict|adaptive|off")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.LBTopK <= 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.StickySessionTTLSeconds <= 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.sticky_session_ttl_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.sticky_response_id_ttl_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 ||
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 ||
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 ||
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 ||
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative")
|
||||||
|
}
|
||||||
|
weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority +
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load +
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue +
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate +
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT
|
||||||
|
if weightSum <= 0 {
|
||||||
|
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero")
|
||||||
|
}
|
||||||
if c.Gateway.MaxLineSize < 0 {
|
if c.Gateway.MaxLineSize < 0 {
|
||||||
return fmt.Errorf("gateway.max_line_size must be non-negative")
|
return fmt.Errorf("gateway.max_line_size must be non-negative")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func resetViperWithJWTSecret(t *testing.T) {
|
func resetViperWithJWTSecret(t *testing.T) {
|
||||||
@@ -75,6 +76,103 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
|
||||||
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.Gateway.OpenAIWS.Enabled {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.Enabled = false, want true")
|
||||||
|
}
|
||||||
|
if !cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.ResponsesWebsocketsV2 = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.ResponsesWebsockets {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.ResponsesWebsockets = true, want false")
|
||||||
|
}
|
||||||
|
if !cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor != 1.0 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.OAuthMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor != 1.0 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.APIKeyMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.StickySessionTTLSeconds != 3600 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.StickySessionTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickySessionTTLSeconds)
|
||||||
|
}
|
||||||
|
if !cfg.Gateway.OpenAIWS.SessionHashReadOldFallback {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.SessionHashReadOldFallback = false, want true")
|
||||||
|
}
|
||||||
|
if !cfg.Gateway.OpenAIWS.SessionHashDualWriteOld {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.SessionHashDualWriteOld = false, want true")
|
||||||
|
}
|
||||||
|
if !cfg.Gateway.OpenAIWS.MetadataBridgeEnabled {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.MetadataBridgeEnabled = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 3600 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.StickyResponseIDTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.FallbackCooldownSeconds != 30 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.FallbackCooldownSeconds = %d, want 30", cfg.Gateway.OpenAIWS.FallbackCooldownSeconds)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.EventFlushBatchSize != 1 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.EventFlushBatchSize = %d, want 1", cfg.Gateway.OpenAIWS.EventFlushBatchSize)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.EventFlushIntervalMS != 10 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.EventFlushIntervalMS = %d, want 10", cfg.Gateway.OpenAIWS.EventFlushIntervalMS)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.PrewarmCooldownMS != 300 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.PrewarmCooldownMS = %d, want 300", cfg.Gateway.OpenAIWS.PrewarmCooldownMS)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.RetryBackoffInitialMS != 120 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.RetryBackoffInitialMS = %d, want 120", cfg.Gateway.OpenAIWS.RetryBackoffInitialMS)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.RetryBackoffMaxMS != 2000 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.RetryBackoffMaxMS = %d, want 2000", cfg.Gateway.OpenAIWS.RetryBackoffMaxMS)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.RetryJitterRatio != 0.2 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.RetryJitterRatio = %v, want 0.2", cfg.Gateway.OpenAIWS.RetryJitterRatio)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.RetryTotalBudgetMS != 5000 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.RetryTotalBudgetMS = %d, want 5000", cfg.Gateway.OpenAIWS.RetryTotalBudgetMS)
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.PayloadLogSampleRate != 0.2 {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.PayloadLogSampleRate = %v, want 0.2", cfg.Gateway.OpenAIWS.PayloadLogSampleRate)
|
||||||
|
}
|
||||||
|
if !cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.StoreDisabledForceNewConn = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.StoreDisabledConnMode != "strict" {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.StoreDisabledConnMode = %q, want %q", cfg.Gateway.OpenAIWS.StoreDisabledConnMode, "strict")
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
|
||||||
|
}
|
||||||
|
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
|
||||||
|
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadOpenAIWSStickyTTLCompatibility(t *testing.T) {
|
||||||
|
resetViperWithJWTSecret(t)
|
||||||
|
t.Setenv("GATEWAY_OPENAI_WS_STICKY_RESPONSE_ID_TTL_SECONDS", "0")
|
||||||
|
t.Setenv("GATEWAY_OPENAI_WS_STICKY_PREVIOUS_RESPONSE_TTL_SECONDS", "7200")
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 7200 {
|
||||||
|
t.Fatalf("StickyResponseIDTTLSeconds = %d, want 7200", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLoadDefaultIdempotencyConfig(t *testing.T) {
|
func TestLoadDefaultIdempotencyConfig(t *testing.T) {
|
||||||
resetViperWithJWTSecret(t)
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
@@ -993,6 +1091,16 @@ func TestValidateConfigErrors(t *testing.T) {
|
|||||||
mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
|
mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
|
||||||
wantErr: "gateway.stream_keepalive_interval",
|
wantErr: "gateway.stream_keepalive_interval",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "gateway openai ws oauth max conns factor",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0 },
|
||||||
|
wantErr: "gateway.openai_ws.oauth_max_conns_factor",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gateway openai ws apikey max conns factor",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0 },
|
||||||
|
wantErr: "gateway.openai_ws.apikey_max_conns_factor",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "gateway stream data interval range",
|
name: "gateway stream data interval range",
|
||||||
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
|
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
|
||||||
@@ -1174,6 +1282,165 @@ func TestValidateConfigErrors(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateConfig_OpenAIWSRules(t *testing.T) {
|
||||||
|
buildValid := func(t *testing.T) *Config {
|
||||||
|
t.Helper()
|
||||||
|
resetViperWithJWTSecret(t)
|
||||||
|
cfg, err := Load()
|
||||||
|
require.NoError(t, err)
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("sticky response id ttl 兼容旧键回填", func(t *testing.T) {
|
||||||
|
cfg := buildValid(t)
|
||||||
|
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0
|
||||||
|
cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 7200
|
||||||
|
|
||||||
|
require.NoError(t, cfg.Validate())
|
||||||
|
require.Equal(t, 7200, cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
|
||||||
|
})
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
mutate func(*Config)
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "max_conns_per_account 必须为正数",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxConnsPerAccount = 0 },
|
||||||
|
wantErr: "gateway.openai_ws.max_conns_per_account",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "min_idle_per_account 不能为负数",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.MinIdlePerAccount = -1 },
|
||||||
|
wantErr: "gateway.openai_ws.min_idle_per_account",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "max_idle_per_account 不能为负数",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxIdlePerAccount = -1 },
|
||||||
|
wantErr: "gateway.openai_ws.max_idle_per_account",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "min_idle_per_account 不能大于 max_idle_per_account",
|
||||||
|
mutate: func(c *Config) {
|
||||||
|
c.Gateway.OpenAIWS.MinIdlePerAccount = 3
|
||||||
|
c.Gateway.OpenAIWS.MaxIdlePerAccount = 2
|
||||||
|
},
|
||||||
|
wantErr: "gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "max_idle_per_account 不能大于 max_conns_per_account",
|
||||||
|
mutate: func(c *Config) {
|
||||||
|
c.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||||||
|
c.Gateway.OpenAIWS.MinIdlePerAccount = 1
|
||||||
|
c.Gateway.OpenAIWS.MaxIdlePerAccount = 3
|
||||||
|
},
|
||||||
|
wantErr: "gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dial_timeout_seconds 必须为正数",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.DialTimeoutSeconds = 0 },
|
||||||
|
wantErr: "gateway.openai_ws.dial_timeout_seconds",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "read_timeout_seconds 必须为正数",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.ReadTimeoutSeconds = 0 },
|
||||||
|
wantErr: "gateway.openai_ws.read_timeout_seconds",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "write_timeout_seconds 必须为正数",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.WriteTimeoutSeconds = 0 },
|
||||||
|
wantErr: "gateway.openai_ws.write_timeout_seconds",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pool_target_utilization 必须在 (0,1]",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.PoolTargetUtilization = 0 },
|
||||||
|
wantErr: "gateway.openai_ws.pool_target_utilization",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "queue_limit_per_conn 必须为正数",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.QueueLimitPerConn = 0 },
|
||||||
|
wantErr: "gateway.openai_ws.queue_limit_per_conn",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fallback_cooldown_seconds 不能为负数",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.FallbackCooldownSeconds = -1 },
|
||||||
|
wantErr: "gateway.openai_ws.fallback_cooldown_seconds",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "store_disabled_conn_mode 必须为 strict|adaptive|off",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.StoreDisabledConnMode = "invalid" },
|
||||||
|
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ingress_mode_default 必须为 off|shared|dedicated",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
|
||||||
|
wantErr: "gateway.openai_ws.ingress_mode_default",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "payload_log_sample_rate 必须在 [0,1] 范围内",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.PayloadLogSampleRate = 1.2 },
|
||||||
|
wantErr: "gateway.openai_ws.payload_log_sample_rate",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retry_total_budget_ms 不能为负数",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.RetryTotalBudgetMS = -1 },
|
||||||
|
wantErr: "gateway.openai_ws.retry_total_budget_ms",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "lb_top_k 必须为正数",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.LBTopK = 0 },
|
||||||
|
wantErr: "gateway.openai_ws.lb_top_k",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sticky_session_ttl_seconds 必须为正数",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.StickySessionTTLSeconds = 0 },
|
||||||
|
wantErr: "gateway.openai_ws.sticky_session_ttl_seconds",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sticky_response_id_ttl_seconds 必须为正数",
|
||||||
|
mutate: func(c *Config) {
|
||||||
|
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0
|
||||||
|
c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 0
|
||||||
|
},
|
||||||
|
wantErr: "gateway.openai_ws.sticky_response_id_ttl_seconds",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sticky_previous_response_ttl_seconds 不能为负数",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = -1 },
|
||||||
|
wantErr: "gateway.openai_ws.sticky_previous_response_ttl_seconds",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "scheduler_score_weights 不能为负数",
|
||||||
|
mutate: func(c *Config) { c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = -0.1 },
|
||||||
|
wantErr: "gateway.openai_ws.scheduler_score_weights.* must be non-negative",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "scheduler_score_weights 不能全为 0",
|
||||||
|
mutate: func(c *Config) {
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0
|
||||||
|
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0
|
||||||
|
},
|
||||||
|
wantErr: "gateway.openai_ws.scheduler_score_weights must not all be zero",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
cfg := buildValid(t)
|
||||||
|
tc.mutate(cfg)
|
||||||
|
|
||||||
|
err := cfg.Validate()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), tc.wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
|
func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
|
||||||
resetViperWithJWTSecret(t)
|
resetViperWithJWTSecret(t)
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
|
|||||||
@@ -104,6 +104,9 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||||
// Gemini 3.1 image preview 映射
|
// Gemini 3.1 image preview 映射
|
||||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||||
|
// Gemini 3 image 兼容映射(向 3.1 image 迁移)
|
||||||
|
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||||
|
"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
|
||||||
// 其他官方模型
|
// 其他官方模型
|
||||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||||
|
|||||||
24
backend/internal/domain/constants_test.go
Normal file
24
backend/internal/domain/constants_test.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cases := map[string]string{
|
||||||
|
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||||
|
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||||
|
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||||
|
"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
|
||||||
|
}
|
||||||
|
|
||||||
|
for from, want := range cases {
|
||||||
|
got, ok := DefaultAntigravityModelMapping[from]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected mapping for %q to exist", from)
|
||||||
|
}
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("unexpected mapping for %q: got %q want %q", from, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1337,6 +1337,34 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
|
|||||||
response.Success(c, stats)
|
response.Success(c, stats)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchTodayStatsRequest 批量今日统计请求体。
|
||||||
|
type BatchTodayStatsRequest struct {
|
||||||
|
AccountIDs []int64 `json:"account_ids" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBatchTodayStats 批量获取多个账号的今日统计。
|
||||||
|
// POST /api/v1/admin/accounts/today-stats/batch
|
||||||
|
func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
|
||||||
|
var req BatchTodayStatsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(req.AccountIDs) == 0 {
|
||||||
|
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"stats": stats})
|
||||||
|
}
|
||||||
|
|
||||||
// SetSchedulableRequest represents the request body for setting schedulable status
|
// SetSchedulableRequest represents the request body for setting schedulable status
|
||||||
type SetSchedulableRequest struct {
|
type SetSchedulableRequest struct {
|
||||||
Schedulable bool `json:"schedulable"`
|
Schedulable bool `json:"schedulable"`
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package admin
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
@@ -186,7 +187,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
|||||||
|
|
||||||
// GetUsageTrend handles getting usage trend data
|
// GetUsageTrend handles getting usage trend data
|
||||||
// GET /api/v1/admin/dashboard/trend
|
// GET /api/v1/admin/dashboard/trend
|
||||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
|
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type
|
||||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||||
startTime, endTime := parseTimeRange(c)
|
startTime, endTime := parseTimeRange(c)
|
||||||
granularity := c.DefaultQuery("granularity", "day")
|
granularity := c.DefaultQuery("granularity", "day")
|
||||||
@@ -194,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
|||||||
// Parse optional filter params
|
// Parse optional filter params
|
||||||
var userID, apiKeyID, accountID, groupID int64
|
var userID, apiKeyID, accountID, groupID int64
|
||||||
var model string
|
var model string
|
||||||
|
var requestType *int16
|
||||||
var stream *bool
|
var stream *bool
|
||||||
var billingType *int8
|
var billingType *int8
|
||||||
|
|
||||||
@@ -220,9 +222,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
|||||||
if modelStr := c.Query("model"); modelStr != "" {
|
if modelStr := c.Query("model"); modelStr != "" {
|
||||||
model = modelStr
|
model = modelStr
|
||||||
}
|
}
|
||||||
if streamStr := c.Query("stream"); streamStr != "" {
|
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||||
|
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value := int16(parsed)
|
||||||
|
requestType = &value
|
||||||
|
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||||
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
|
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
|
||||||
stream = &streamVal
|
stream = &streamVal
|
||||||
|
} else {
|
||||||
|
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||||
@@ -235,7 +248,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
|
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get usage trend")
|
response.Error(c, 500, "Failed to get usage trend")
|
||||||
return
|
return
|
||||||
@@ -251,12 +264,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
|||||||
|
|
||||||
// GetModelStats handles getting model usage statistics
|
// GetModelStats handles getting model usage statistics
|
||||||
// GET /api/v1/admin/dashboard/models
|
// GET /api/v1/admin/dashboard/models
|
||||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
|
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
|
||||||
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||||
startTime, endTime := parseTimeRange(c)
|
startTime, endTime := parseTimeRange(c)
|
||||||
|
|
||||||
// Parse optional filter params
|
// Parse optional filter params
|
||||||
var userID, apiKeyID, accountID, groupID int64
|
var userID, apiKeyID, accountID, groupID int64
|
||||||
|
var requestType *int16
|
||||||
var stream *bool
|
var stream *bool
|
||||||
var billingType *int8
|
var billingType *int8
|
||||||
|
|
||||||
@@ -280,9 +294,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
groupID = id
|
groupID = id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if streamStr := c.Query("stream"); streamStr != "" {
|
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||||
|
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value := int16(parsed)
|
||||||
|
requestType = &value
|
||||||
|
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||||
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
|
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
|
||||||
stream = &streamVal
|
stream = &streamVal
|
||||||
|
} else {
|
||||||
|
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||||
@@ -295,7 +320,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
|
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get model statistics")
|
response.Error(c, 500, "Failed to get model statistics")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -0,0 +1,132 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dashboardUsageRepoCapture struct {
|
||||||
|
service.UsageLogRepository
|
||||||
|
trendRequestType *int16
|
||||||
|
trendStream *bool
|
||||||
|
modelRequestType *int16
|
||||||
|
modelStream *bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
granularity string,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
model string,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.TrendDataPoint, error) {
|
||||||
|
s.trendRequestType = requestType
|
||||||
|
s.trendStream = stream
|
||||||
|
return []usagestats.TrendDataPoint{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
userID, apiKeyID, accountID, groupID int64,
|
||||||
|
requestType *int16,
|
||||||
|
stream *bool,
|
||||||
|
billingType *int8,
|
||||||
|
) ([]usagestats.ModelStat, error) {
|
||||||
|
s.modelRequestType = requestType
|
||||||
|
s.modelStream = stream
|
||||||
|
return []usagestats.ModelStat{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
|
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||||
|
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardTrendRequestTypePriority(t *testing.T) {
|
||||||
|
repo := &dashboardUsageRepoCapture{}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.NotNil(t, repo.trendRequestType)
|
||||||
|
require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType)
|
||||||
|
require.Nil(t, repo.trendStream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardTrendInvalidRequestType(t *testing.T) {
|
||||||
|
repo := &dashboardUsageRepoCapture{}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardTrendInvalidStream(t *testing.T) {
|
||||||
|
repo := &dashboardUsageRepoCapture{}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardModelStatsRequestTypePriority(t *testing.T) {
|
||||||
|
repo := &dashboardUsageRepoCapture{}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.NotNil(t, repo.modelRequestType)
|
||||||
|
require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType)
|
||||||
|
require.Nil(t, repo.modelStream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardModelStatsInvalidRequestType(t *testing.T) {
|
||||||
|
repo := &dashboardUsageRepoCapture{}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
||||||
|
repo := &dashboardUsageRepoCapture{}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
523
backend/internal/handler/admin/data_management_handler.go
Normal file
523
backend/internal/handler/admin/data_management_handler.go
Normal file
@@ -0,0 +1,523 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DataManagementHandler struct {
|
||||||
|
dataManagementService *service.DataManagementService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler {
|
||||||
|
return &DataManagementHandler{dataManagementService: dataManagementService}
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestS3ConnectionRequest struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
Region string `json:"region" binding:"required"`
|
||||||
|
Bucket string `json:"bucket" binding:"required"`
|
||||||
|
AccessKeyID string `json:"access_key_id"`
|
||||||
|
SecretAccessKey string `json:"secret_access_key"`
|
||||||
|
Prefix string `json:"prefix"`
|
||||||
|
ForcePathStyle bool `json:"force_path_style"`
|
||||||
|
UseSSL bool `json:"use_ssl"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateBackupJobRequest struct {
|
||||||
|
BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"`
|
||||||
|
UploadToS3 bool `json:"upload_to_s3"`
|
||||||
|
S3ProfileID string `json:"s3_profile_id"`
|
||||||
|
PostgresID string `json:"postgres_profile_id"`
|
||||||
|
RedisID string `json:"redis_profile_id"`
|
||||||
|
IdempotencyKey string `json:"idempotency_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateSourceProfileRequest struct {
|
||||||
|
ProfileID string `json:"profile_id" binding:"required"`
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
|
||||||
|
SetActive bool `json:"set_active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateSourceProfileRequest struct {
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateS3ProfileRequest struct {
|
||||||
|
ProfileID string `json:"profile_id" binding:"required"`
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
Bucket string `json:"bucket"`
|
||||||
|
AccessKeyID string `json:"access_key_id"`
|
||||||
|
SecretAccessKey string `json:"secret_access_key"`
|
||||||
|
Prefix string `json:"prefix"`
|
||||||
|
ForcePathStyle bool `json:"force_path_style"`
|
||||||
|
UseSSL bool `json:"use_ssl"`
|
||||||
|
SetActive bool `json:"set_active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateS3ProfileRequest struct {
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
Bucket string `json:"bucket"`
|
||||||
|
AccessKeyID string `json:"access_key_id"`
|
||||||
|
SecretAccessKey string `json:"secret_access_key"`
|
||||||
|
Prefix string `json:"prefix"`
|
||||||
|
ForcePathStyle bool `json:"force_path_style"`
|
||||||
|
UseSSL bool `json:"use_ssl"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) {
|
||||||
|
health := h.getAgentHealth(c)
|
||||||
|
payload := gin.H{
|
||||||
|
"enabled": health.Enabled,
|
||||||
|
"reason": health.Reason,
|
||||||
|
"socket_path": health.SocketPath,
|
||||||
|
}
|
||||||
|
if health.Agent != nil {
|
||||||
|
payload["agent"] = gin.H{
|
||||||
|
"status": health.Agent.Status,
|
||||||
|
"version": health.Agent.Version,
|
||||||
|
"uptime_seconds": health.Agent.UptimeSeconds,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response.Success(c, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) GetConfig(c *gin.Context) {
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg, err := h.dataManagementService.GetConfig(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) UpdateConfig(c *gin.Context) {
|
||||||
|
var req service.DataManagementConfig
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) TestS3(c *gin.Context) {
|
||||||
|
var req TestS3ConnectionRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{
|
||||||
|
Enabled: true,
|
||||||
|
Endpoint: req.Endpoint,
|
||||||
|
Region: req.Region,
|
||||||
|
Bucket: req.Bucket,
|
||||||
|
AccessKeyID: req.AccessKeyID,
|
||||||
|
SecretAccessKey: req.SecretAccessKey,
|
||||||
|
Prefix: req.Prefix,
|
||||||
|
ForcePathStyle: req.ForcePathStyle,
|
||||||
|
UseSSL: req.UseSSL,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"ok": result.OK, "message": result.Message})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
|
||||||
|
var req CreateBackupJobRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey)
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
triggeredBy := "admin:unknown"
|
||||||
|
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
|
||||||
|
triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10)
|
||||||
|
}
|
||||||
|
job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{
|
||||||
|
BackupType: req.BackupType,
|
||||||
|
UploadToS3: req.UploadToS3,
|
||||||
|
S3ProfileID: req.S3ProfileID,
|
||||||
|
PostgresID: req.PostgresID,
|
||||||
|
RedisID: req.RedisID,
|
||||||
|
TriggeredBy: triggeredBy,
|
||||||
|
IdempotencyKey: req.IdempotencyKey,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) {
|
||||||
|
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||||
|
if sourceType == "" {
|
||||||
|
response.BadRequest(c, "Invalid source_type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if sourceType != "postgres" && sourceType != "redis" {
|
||||||
|
response.BadRequest(c, "source_type must be postgres or redis")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"items": items})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) {
|
||||||
|
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||||
|
if sourceType != "postgres" && sourceType != "redis" {
|
||||||
|
response.BadRequest(c, "source_type must be postgres or redis")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req CreateSourceProfileRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{
|
||||||
|
SourceType: sourceType,
|
||||||
|
ProfileID: req.ProfileID,
|
||||||
|
Name: req.Name,
|
||||||
|
Config: req.Config,
|
||||||
|
SetActive: req.SetActive,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, profile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) {
|
||||||
|
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||||
|
if sourceType != "postgres" && sourceType != "redis" {
|
||||||
|
response.BadRequest(c, "source_type must be postgres or redis")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||||
|
if profileID == "" {
|
||||||
|
response.BadRequest(c, "Invalid profile_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req UpdateSourceProfileRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{
|
||||||
|
SourceType: sourceType,
|
||||||
|
ProfileID: profileID,
|
||||||
|
Name: req.Name,
|
||||||
|
Config: req.Config,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, profile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) {
|
||||||
|
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||||
|
if sourceType != "postgres" && sourceType != "redis" {
|
||||||
|
response.BadRequest(c, "source_type must be postgres or redis")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||||
|
if profileID == "" {
|
||||||
|
response.BadRequest(c, "Invalid profile_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"deleted": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) {
|
||||||
|
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||||
|
if sourceType != "postgres" && sourceType != "redis" {
|
||||||
|
response.BadRequest(c, "source_type must be postgres or redis")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||||
|
if profileID == "" {
|
||||||
|
response.BadRequest(c, "Invalid profile_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, profile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) {
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
items, err := h.dataManagementService.ListS3Profiles(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"items": items})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
|
||||||
|
var req CreateS3ProfileRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{
|
||||||
|
ProfileID: req.ProfileID,
|
||||||
|
Name: req.Name,
|
||||||
|
SetActive: req.SetActive,
|
||||||
|
S3: service.DataManagementS3Config{
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
Endpoint: req.Endpoint,
|
||||||
|
Region: req.Region,
|
||||||
|
Bucket: req.Bucket,
|
||||||
|
AccessKeyID: req.AccessKeyID,
|
||||||
|
SecretAccessKey: req.SecretAccessKey,
|
||||||
|
Prefix: req.Prefix,
|
||||||
|
ForcePathStyle: req.ForcePathStyle,
|
||||||
|
UseSSL: req.UseSSL,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, profile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
|
||||||
|
var req UpdateS3ProfileRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||||
|
if profileID == "" {
|
||||||
|
response.BadRequest(c, "Invalid profile_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{
|
||||||
|
ProfileID: profileID,
|
||||||
|
Name: req.Name,
|
||||||
|
S3: service.DataManagementS3Config{
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
Endpoint: req.Endpoint,
|
||||||
|
Region: req.Region,
|
||||||
|
Bucket: req.Bucket,
|
||||||
|
AccessKeyID: req.AccessKeyID,
|
||||||
|
SecretAccessKey: req.SecretAccessKey,
|
||||||
|
Prefix: req.Prefix,
|
||||||
|
ForcePathStyle: req.ForcePathStyle,
|
||||||
|
UseSSL: req.UseSSL,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, profile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) {
|
||||||
|
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||||
|
if profileID == "" {
|
||||||
|
response.BadRequest(c, "Invalid profile_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"deleted": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) {
|
||||||
|
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||||
|
if profileID == "" {
|
||||||
|
response.BadRequest(c, "Invalid profile_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, profile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) {
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pageSize := int32(20)
|
||||||
|
if raw := strings.TrimSpace(c.Query("page_size")); raw != "" {
|
||||||
|
v, err := strconv.Atoi(raw)
|
||||||
|
if err != nil || v <= 0 {
|
||||||
|
response.BadRequest(c, "Invalid page_size")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pageSize = int32(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{
|
||||||
|
PageSize: pageSize,
|
||||||
|
PageToken: c.Query("page_token"),
|
||||||
|
Status: c.Query("status"),
|
||||||
|
BackupType: c.Query("backup_type"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) GetBackupJob(c *gin.Context) {
|
||||||
|
jobID := strings.TrimSpace(c.Param("job_id"))
|
||||||
|
if jobID == "" {
|
||||||
|
response.BadRequest(c, "Invalid backup job ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.requireAgentEnabled(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, job)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool {
|
||||||
|
if h.dataManagementService == nil {
|
||||||
|
err := infraerrors.ServiceUnavailable(
|
||||||
|
service.DataManagementAgentUnavailableReason,
|
||||||
|
"data management agent service is not configured",
|
||||||
|
).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath})
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth {
|
||||||
|
if h.dataManagementService == nil {
|
||||||
|
return service.DataManagementAgentHealth{
|
||||||
|
Enabled: false,
|
||||||
|
Reason: service.DataManagementAgentUnavailableReason,
|
||||||
|
SocketPath: service.DefaultDataManagementAgentSocketPath,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return h.dataManagementService.GetAgentHealth(c.Request.Context())
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeBackupIdempotencyKey picks the idempotency key to use: the
// trimmed header value wins; otherwise the trimmed body value is used.
func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string {
	if key := strings.TrimSpace(headerValue); key != "" {
		return key
	}
	return strings.TrimSpace(bodyValue)
}
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type apiEnvelope struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
Data json.RawMessage `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
|
||||||
|
h := NewDataManagementHandler(svc)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil)
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
var envelope apiEnvelope
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
|
||||||
|
require.Equal(t, 0, envelope.Code)
|
||||||
|
|
||||||
|
var data struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
SocketPath string `json:"socket_path"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal(envelope.Data, &data))
|
||||||
|
require.False(t, data.Enabled)
|
||||||
|
require.Equal(t, service.DataManagementDeprecatedReason, data.Reason)
|
||||||
|
require.Equal(t, svc.SocketPath(), data.SocketPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
|
||||||
|
h := NewDataManagementHandler(svc)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.GET("/api/v1/admin/data-management/config", h.GetConfig)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil)
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||||
|
|
||||||
|
var envelope apiEnvelope
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
|
||||||
|
require.Equal(t, http.StatusServiceUnavailable, envelope.Code)
|
||||||
|
require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeBackupIdempotencyKey(t *testing.T) {
|
||||||
|
require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body"))
|
||||||
|
require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body "))
|
||||||
|
require.Equal(t, "", normalizeBackupIdempotencyKey("", ""))
|
||||||
|
}
|
||||||
@@ -51,6 +51,8 @@ type CreateGroupRequest struct {
|
|||||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||||
|
// Sora 存储配额
|
||||||
|
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||||
// 从指定分组复制账号(创建后自动绑定)
|
// 从指定分组复制账号(创建后自动绑定)
|
||||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||||
}
|
}
|
||||||
@@ -84,6 +86,8 @@ type UpdateGroupRequest struct {
|
|||||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||||
|
// Sora 存储配额
|
||||||
|
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||||
}
|
}
|
||||||
@@ -198,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
MCPXMLInject: req.MCPXMLInject,
|
MCPXMLInject: req.MCPXMLInject,
|
||||||
SupportedModelScopes: req.SupportedModelScopes,
|
SupportedModelScopes: req.SupportedModelScopes,
|
||||||
|
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -248,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
MCPXMLInject: req.MCPXMLInject,
|
MCPXMLInject: req.MCPXMLInject,
|
||||||
SupportedModelScopes: req.SupportedModelScopes,
|
SupportedModelScopes: req.SupportedModelScopes,
|
||||||
|
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
@@ -47,7 +48,12 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
|||||||
req = OpenAIGenerateAuthURLRequest{}
|
req = OpenAIGenerateAuthURLRequest{}
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
|
result, err := h.openaiOAuthService.GenerateAuthURL(
|
||||||
|
c.Request.Context(),
|
||||||
|
req.ProxyID,
|
||||||
|
req.RedirectURI,
|
||||||
|
oauthPlatformFromPath(c),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@@ -123,7 +129,14 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
|
// 未指定 client_id 时,根据请求路径平台自动设置默认值,避免 repository 层盲猜
|
||||||
|
clientID := strings.TrimSpace(req.ClientID)
|
||||||
|
if clientID == "" {
|
||||||
|
platform := oauthPlatformFromPath(c)
|
||||||
|
clientID, _ = openai.OAuthClientConfigByPlatform(platform)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -62,7 +62,8 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var wsConnCount atomic.Int32
|
var wsConnCount atomic.Int32
|
||||||
var wsConnCountByIP sync.Map // map[string]*atomic.Int32
|
var wsConnCountByIPMu sync.Mutex
|
||||||
|
var wsConnCountByIP = make(map[string]int32)
|
||||||
|
|
||||||
const qpsWSIdleStopDelay = 30 * time.Second
|
const qpsWSIdleStopDelay = 30 * time.Second
|
||||||
|
|
||||||
@@ -389,42 +390,31 @@ func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool {
|
|||||||
if strings.TrimSpace(clientIP) == "" || limit <= 0 {
|
if strings.TrimSpace(clientIP) == "" || limit <= 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
wsConnCountByIPMu.Lock()
|
||||||
v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{})
|
defer wsConnCountByIPMu.Unlock()
|
||||||
counter, ok := v.(*atomic.Int32)
|
current := wsConnCountByIP[clientIP]
|
||||||
if !ok {
|
if current >= limit {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
wsConnCountByIP[clientIP] = current + 1
|
||||||
for {
|
return true
|
||||||
current := counter.Load()
|
|
||||||
if current >= limit {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if counter.CompareAndSwap(current, current+1) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func releaseOpsWSIPSlot(clientIP string) {
|
func releaseOpsWSIPSlot(clientIP string) {
|
||||||
if strings.TrimSpace(clientIP) == "" {
|
if strings.TrimSpace(clientIP) == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
wsConnCountByIPMu.Lock()
|
||||||
v, ok := wsConnCountByIP.Load(clientIP)
|
defer wsConnCountByIPMu.Unlock()
|
||||||
|
current, ok := wsConnCountByIP[clientIP]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
counter, ok := v.(*atomic.Int32)
|
if current <= 1 {
|
||||||
if !ok {
|
delete(wsConnCountByIP, clientIP)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next := counter.Add(-1)
|
wsConnCountByIP[clientIP] = current - 1
|
||||||
if next <= 0 {
|
|
||||||
// Best-effort cleanup; safe even if a new slot was acquired concurrently.
|
|
||||||
wsConnCountByIP.Delete(clientIP)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -20,15 +21,17 @@ type SettingHandler struct {
|
|||||||
emailService *service.EmailService
|
emailService *service.EmailService
|
||||||
turnstileService *service.TurnstileService
|
turnstileService *service.TurnstileService
|
||||||
opsService *service.OpsService
|
opsService *service.OpsService
|
||||||
|
soraS3Storage *service.SoraS3Storage
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSettingHandler 创建系统设置处理器
|
// NewSettingHandler 创建系统设置处理器
|
||||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler {
|
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler {
|
||||||
return &SettingHandler{
|
return &SettingHandler{
|
||||||
settingService: settingService,
|
settingService: settingService,
|
||||||
emailService: emailService,
|
emailService: emailService,
|
||||||
turnstileService: turnstileService,
|
turnstileService: turnstileService,
|
||||||
opsService: opsService,
|
opsService: opsService,
|
||||||
|
soraS3Storage: soraS3Storage,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,6 +79,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
HideCcsImportButton: settings.HideCcsImportButton,
|
HideCcsImportButton: settings.HideCcsImportButton,
|
||||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||||
|
SoraClientEnabled: settings.SoraClientEnabled,
|
||||||
DefaultConcurrency: settings.DefaultConcurrency,
|
DefaultConcurrency: settings.DefaultConcurrency,
|
||||||
DefaultBalance: settings.DefaultBalance,
|
DefaultBalance: settings.DefaultBalance,
|
||||||
EnableModelFallback: settings.EnableModelFallback,
|
EnableModelFallback: settings.EnableModelFallback,
|
||||||
@@ -133,6 +137,7 @@ type UpdateSettingsRequest struct {
|
|||||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||||
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
||||||
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
||||||
|
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||||
|
|
||||||
// 默认配置
|
// 默认配置
|
||||||
DefaultConcurrency int `json:"default_concurrency"`
|
DefaultConcurrency int `json:"default_concurrency"`
|
||||||
@@ -319,6 +324,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
HideCcsImportButton: req.HideCcsImportButton,
|
HideCcsImportButton: req.HideCcsImportButton,
|
||||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||||
PurchaseSubscriptionURL: purchaseURL,
|
PurchaseSubscriptionURL: purchaseURL,
|
||||||
|
SoraClientEnabled: req.SoraClientEnabled,
|
||||||
DefaultConcurrency: req.DefaultConcurrency,
|
DefaultConcurrency: req.DefaultConcurrency,
|
||||||
DefaultBalance: req.DefaultBalance,
|
DefaultBalance: req.DefaultBalance,
|
||||||
EnableModelFallback: req.EnableModelFallback,
|
EnableModelFallback: req.EnableModelFallback,
|
||||||
@@ -400,6 +406,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
||||||
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
|
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
|
||||||
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
||||||
|
SoraClientEnabled: updatedSettings.SoraClientEnabled,
|
||||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||||
DefaultBalance: updatedSettings.DefaultBalance,
|
DefaultBalance: updatedSettings.DefaultBalance,
|
||||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||||
@@ -750,6 +757,384 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings {
|
||||||
|
if settings == nil {
|
||||||
|
return dto.SoraS3Settings{}
|
||||||
|
}
|
||||||
|
return dto.SoraS3Settings{
|
||||||
|
Enabled: settings.Enabled,
|
||||||
|
Endpoint: settings.Endpoint,
|
||||||
|
Region: settings.Region,
|
||||||
|
Bucket: settings.Bucket,
|
||||||
|
AccessKeyID: settings.AccessKeyID,
|
||||||
|
SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured,
|
||||||
|
Prefix: settings.Prefix,
|
||||||
|
ForcePathStyle: settings.ForcePathStyle,
|
||||||
|
CDNURL: settings.CDNURL,
|
||||||
|
DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile {
|
||||||
|
return dto.SoraS3Profile{
|
||||||
|
ProfileID: profile.ProfileID,
|
||||||
|
Name: profile.Name,
|
||||||
|
IsActive: profile.IsActive,
|
||||||
|
Enabled: profile.Enabled,
|
||||||
|
Endpoint: profile.Endpoint,
|
||||||
|
Region: profile.Region,
|
||||||
|
Bucket: profile.Bucket,
|
||||||
|
AccessKeyID: profile.AccessKeyID,
|
||||||
|
SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured,
|
||||||
|
Prefix: profile.Prefix,
|
||||||
|
ForcePathStyle: profile.ForcePathStyle,
|
||||||
|
CDNURL: profile.CDNURL,
|
||||||
|
DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes,
|
||||||
|
UpdatedAt: profile.UpdatedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateSoraS3RequiredWhenEnabled checks that the fields mandatory for an
// enabled S3 config are present. The secret is satisfied by either a newly
// supplied value or a previously stored one. Disabled configs always pass.
func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error {
	if !enabled {
		return nil
	}
	switch {
	case strings.TrimSpace(endpoint) == "":
		return fmt.Errorf("S3 Endpoint is required when enabled")
	case strings.TrimSpace(bucket) == "":
		return fmt.Errorf("S3 Bucket is required when enabled")
	case strings.TrimSpace(accessKeyID) == "":
		return fmt.Errorf("S3 Access Key ID is required when enabled")
	}
	if strings.TrimSpace(secretAccessKey) == "" && !hasStoredSecret {
		return fmt.Errorf("S3 Secret Access Key is required when enabled")
	}
	return nil
}
|
||||||
|
|
||||||
|
func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile {
|
||||||
|
for idx := range items {
|
||||||
|
if items[idx].ProfileID == profileID {
|
||||||
|
return &items[idx]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口)
|
||||||
|
// GET /api/v1/admin/settings/sora-s3
|
||||||
|
func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) {
|
||||||
|
settings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, toSoraS3SettingsDTO(settings))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSoraS3Profiles 获取 Sora S3 多配置
|
||||||
|
// GET /api/v1/admin/settings/sora-s3/profiles
|
||||||
|
func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) {
|
||||||
|
result, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
items := make([]dto.SoraS3Profile, 0, len(result.Items))
|
||||||
|
for idx := range result.Items {
|
||||||
|
items = append(items, toSoraS3ProfileDTO(result.Items[idx]))
|
||||||
|
}
|
||||||
|
response.Success(c, dto.ListSoraS3ProfilesResponse{
|
||||||
|
ActiveProfileID: result.ActiveProfileID,
|
||||||
|
Items: items,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口)
|
||||||
|
type UpdateSoraS3SettingsRequest struct {
|
||||||
|
ProfileID string `json:"profile_id"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
Bucket string `json:"bucket"`
|
||||||
|
AccessKeyID string `json:"access_key_id"`
|
||||||
|
SecretAccessKey string `json:"secret_access_key"`
|
||||||
|
Prefix string `json:"prefix"`
|
||||||
|
ForcePathStyle bool `json:"force_path_style"`
|
||||||
|
CDNURL string `json:"cdn_url"`
|
||||||
|
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateSoraS3ProfileRequest struct {
|
||||||
|
ProfileID string `json:"profile_id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
SetActive bool `json:"set_active"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
Bucket string `json:"bucket"`
|
||||||
|
AccessKeyID string `json:"access_key_id"`
|
||||||
|
SecretAccessKey string `json:"secret_access_key"`
|
||||||
|
Prefix string `json:"prefix"`
|
||||||
|
ForcePathStyle bool `json:"force_path_style"`
|
||||||
|
CDNURL string `json:"cdn_url"`
|
||||||
|
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateSoraS3ProfileRequest struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
Bucket string `json:"bucket"`
|
||||||
|
AccessKeyID string `json:"access_key_id"`
|
||||||
|
SecretAccessKey string `json:"secret_access_key"`
|
||||||
|
Prefix string `json:"prefix"`
|
||||||
|
ForcePathStyle bool `json:"force_path_style"`
|
||||||
|
CDNURL string `json:"cdn_url"`
|
||||||
|
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSoraS3Profile 创建 Sora S3 配置
|
||||||
|
// POST /api/v1/admin/settings/sora-s3/profiles
|
||||||
|
func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) {
|
||||||
|
var req CreateSoraS3ProfileRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.DefaultStorageQuotaBytes < 0 {
|
||||||
|
req.DefaultStorageQuotaBytes = 0
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.Name) == "" {
|
||||||
|
response.BadRequest(c, "Name is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.ProfileID) == "" {
|
||||||
|
response.BadRequest(c, "Profile ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{
|
||||||
|
ProfileID: req.ProfileID,
|
||||||
|
Name: req.Name,
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
Endpoint: req.Endpoint,
|
||||||
|
Region: req.Region,
|
||||||
|
Bucket: req.Bucket,
|
||||||
|
AccessKeyID: req.AccessKeyID,
|
||||||
|
SecretAccessKey: req.SecretAccessKey,
|
||||||
|
Prefix: req.Prefix,
|
||||||
|
ForcePathStyle: req.ForcePathStyle,
|
||||||
|
CDNURL: req.CDNURL,
|
||||||
|
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||||
|
}, req.SetActive)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, toSoraS3ProfileDTO(*created))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSoraS3Profile 更新 Sora S3 配置
|
||||||
|
// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id
|
||||||
|
func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) {
|
||||||
|
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||||
|
if profileID == "" {
|
||||||
|
response.BadRequest(c, "Profile ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req UpdateSoraS3ProfileRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.DefaultStorageQuotaBytes < 0 {
|
||||||
|
req.DefaultStorageQuotaBytes = 0
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.Name) == "" {
|
||||||
|
response.BadRequest(c, "Name is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
existing := findSoraS3ProfileByID(existingList.Items, profileID)
|
||||||
|
if existing == nil {
|
||||||
|
response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{
|
||||||
|
Name: req.Name,
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
Endpoint: req.Endpoint,
|
||||||
|
Region: req.Region,
|
||||||
|
Bucket: req.Bucket,
|
||||||
|
AccessKeyID: req.AccessKeyID,
|
||||||
|
SecretAccessKey: req.SecretAccessKey,
|
||||||
|
Prefix: req.Prefix,
|
||||||
|
ForcePathStyle: req.ForcePathStyle,
|
||||||
|
CDNURL: req.CDNURL,
|
||||||
|
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||||
|
})
|
||||||
|
if updateErr != nil {
|
||||||
|
response.ErrorFrom(c, updateErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, toSoraS3ProfileDTO(*updated))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSoraS3Profile 删除 Sora S3 配置
|
||||||
|
// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id
|
||||||
|
func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) {
|
||||||
|
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||||
|
if profileID == "" {
|
||||||
|
response.BadRequest(c, "Profile ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"deleted": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetActiveSoraS3Profile 切换激活 Sora S3 配置
|
||||||
|
// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate
|
||||||
|
func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) {
|
||||||
|
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||||
|
if profileID == "" {
|
||||||
|
response.BadRequest(c, "Profile ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, toSoraS3ProfileDTO(*active))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口)
|
||||||
|
// PUT /api/v1/admin/settings/sora-s3
|
||||||
|
func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) {
|
||||||
|
var req UpdateSoraS3SettingsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.DefaultStorageQuotaBytes < 0 {
|
||||||
|
req.DefaultStorageQuotaBytes = 0
|
||||||
|
}
|
||||||
|
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := &service.SoraS3Settings{
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
Endpoint: req.Endpoint,
|
||||||
|
Region: req.Region,
|
||||||
|
Bucket: req.Bucket,
|
||||||
|
AccessKeyID: req.AccessKeyID,
|
||||||
|
SecretAccessKey: req.SecretAccessKey,
|
||||||
|
Prefix: req.Prefix,
|
||||||
|
ForcePathStyle: req.ForcePathStyle,
|
||||||
|
CDNURL: req.CDNURL,
|
||||||
|
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||||
|
}
|
||||||
|
if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, toSoraS3SettingsDTO(updatedSettings))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket)
|
||||||
|
// POST /api/v1/admin/settings/sora-s3/test
|
||||||
|
func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
|
||||||
|
if h.soraS3Storage == nil {
|
||||||
|
response.Error(c, 500, "S3 存储服务未初始化")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req UpdateSoraS3SettingsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !req.Enabled {
|
||||||
|
response.BadRequest(c, "S3 未启用,无法测试连接")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.SecretAccessKey == "" {
|
||||||
|
if req.ProfileID != "" {
|
||||||
|
profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||||
|
if err == nil {
|
||||||
|
profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID)
|
||||||
|
if profile != nil {
|
||||||
|
req.SecretAccessKey = profile.SecretAccessKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if req.SecretAccessKey == "" {
|
||||||
|
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||||
|
if err == nil {
|
||||||
|
req.SecretAccessKey = existing.SecretAccessKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCfg := &service.SoraS3Settings{
|
||||||
|
Enabled: true,
|
||||||
|
Endpoint: req.Endpoint,
|
||||||
|
Region: req.Region,
|
||||||
|
Bucket: req.Bucket,
|
||||||
|
AccessKeyID: req.AccessKeyID,
|
||||||
|
SecretAccessKey: req.SecretAccessKey,
|
||||||
|
Prefix: req.Prefix,
|
||||||
|
ForcePathStyle: req.ForcePathStyle,
|
||||||
|
CDNURL: req.CDNURL,
|
||||||
|
}
|
||||||
|
if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil {
|
||||||
|
response.Error(c, 400, "S3 连接测试失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"message": "S3 连接成功"})
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||||
type UpdateStreamTimeoutSettingsRequest struct {
|
type UpdateStreamTimeoutSettingsRequest struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
|
|||||||
@@ -225,6 +225,92 @@ func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
|
|||||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUsageHandlerCreateCleanupTaskInvalidRequestType(t *testing.T) {
|
||||||
|
repo := &cleanupRepoStub{}
|
||||||
|
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||||
|
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||||
|
router := setupCleanupRouter(cleanupService, 88)
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"start_date": "2024-01-01",
|
||||||
|
"end_date": "2024-01-02",
|
||||||
|
"timezone": "UTC",
|
||||||
|
"request_type": "invalid",
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageHandlerCreateCleanupTaskRequestTypePriority(t *testing.T) {
|
||||||
|
repo := &cleanupRepoStub{}
|
||||||
|
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||||
|
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||||
|
router := setupCleanupRouter(cleanupService, 99)
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"start_date": "2024-01-01",
|
||||||
|
"end_date": "2024-01-02",
|
||||||
|
"timezone": "UTC",
|
||||||
|
"request_type": "ws_v2",
|
||||||
|
"stream": false,
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
repo.mu.Lock()
|
||||||
|
defer repo.mu.Unlock()
|
||||||
|
require.Len(t, repo.created, 1)
|
||||||
|
created := repo.created[0]
|
||||||
|
require.NotNil(t, created.Filters.RequestType)
|
||||||
|
require.Equal(t, int16(service.RequestTypeWSV2), *created.Filters.RequestType)
|
||||||
|
require.Nil(t, created.Filters.Stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageHandlerCreateCleanupTaskWithLegacyStream(t *testing.T) {
|
||||||
|
repo := &cleanupRepoStub{}
|
||||||
|
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||||
|
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||||
|
router := setupCleanupRouter(cleanupService, 99)
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"start_date": "2024-01-01",
|
||||||
|
"end_date": "2024-01-02",
|
||||||
|
"timezone": "UTC",
|
||||||
|
"stream": true,
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
repo.mu.Lock()
|
||||||
|
defer repo.mu.Unlock()
|
||||||
|
require.Len(t, repo.created, 1)
|
||||||
|
created := repo.created[0]
|
||||||
|
require.Nil(t, created.Filters.RequestType)
|
||||||
|
require.NotNil(t, created.Filters.Stream)
|
||||||
|
require.True(t, *created.Filters.Stream)
|
||||||
|
}
|
||||||
|
|
||||||
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
|
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
|
||||||
repo := &cleanupRepoStub{}
|
repo := &cleanupRepoStub{}
|
||||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ type CreateUsageCleanupTaskRequest struct {
|
|||||||
AccountID *int64 `json:"account_id"`
|
AccountID *int64 `json:"account_id"`
|
||||||
GroupID *int64 `json:"group_id"`
|
GroupID *int64 `json:"group_id"`
|
||||||
Model *string `json:"model"`
|
Model *string `json:"model"`
|
||||||
|
RequestType *string `json:"request_type"`
|
||||||
Stream *bool `json:"stream"`
|
Stream *bool `json:"stream"`
|
||||||
BillingType *int8 `json:"billing_type"`
|
BillingType *int8 `json:"billing_type"`
|
||||||
Timezone string `json:"timezone"`
|
Timezone string `json:"timezone"`
|
||||||
@@ -101,8 +102,17 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
|
|
||||||
model := c.Query("model")
|
model := c.Query("model")
|
||||||
|
|
||||||
|
var requestType *int16
|
||||||
var stream *bool
|
var stream *bool
|
||||||
if streamStr := c.Query("stream"); streamStr != "" {
|
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||||
|
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value := int16(parsed)
|
||||||
|
requestType = &value
|
||||||
|
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||||
val, err := strconv.ParseBool(streamStr)
|
val, err := strconv.ParseBool(streamStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||||
@@ -152,6 +162,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
GroupID: groupID,
|
GroupID: groupID,
|
||||||
Model: model,
|
Model: model,
|
||||||
|
RequestType: requestType,
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
BillingType: billingType,
|
BillingType: billingType,
|
||||||
StartTime: startTime,
|
StartTime: startTime,
|
||||||
@@ -214,8 +225,17 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
|||||||
|
|
||||||
model := c.Query("model")
|
model := c.Query("model")
|
||||||
|
|
||||||
|
var requestType *int16
|
||||||
var stream *bool
|
var stream *bool
|
||||||
if streamStr := c.Query("stream"); streamStr != "" {
|
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||||
|
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value := int16(parsed)
|
||||||
|
requestType = &value
|
||||||
|
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||||
val, err := strconv.ParseBool(streamStr)
|
val, err := strconv.ParseBool(streamStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||||
@@ -278,6 +298,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
|||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
GroupID: groupID,
|
GroupID: groupID,
|
||||||
Model: model,
|
Model: model,
|
||||||
|
RequestType: requestType,
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
BillingType: billingType,
|
BillingType: billingType,
|
||||||
StartTime: &startTime,
|
StartTime: &startTime,
|
||||||
@@ -432,6 +453,19 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||||
|
|
||||||
|
var requestType *int16
|
||||||
|
stream := req.Stream
|
||||||
|
if req.RequestType != nil {
|
||||||
|
parsed, err := service.ParseUsageRequestType(*req.RequestType)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value := int16(parsed)
|
||||||
|
requestType = &value
|
||||||
|
stream = nil
|
||||||
|
}
|
||||||
|
|
||||||
filters := service.UsageCleanupFilters{
|
filters := service.UsageCleanupFilters{
|
||||||
StartTime: startTime,
|
StartTime: startTime,
|
||||||
EndTime: endTime,
|
EndTime: endTime,
|
||||||
@@ -440,7 +474,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
|||||||
AccountID: req.AccountID,
|
AccountID: req.AccountID,
|
||||||
GroupID: req.GroupID,
|
GroupID: req.GroupID,
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
Stream: req.Stream,
|
RequestType: requestType,
|
||||||
|
Stream: stream,
|
||||||
BillingType: req.BillingType,
|
BillingType: req.BillingType,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -464,9 +499,13 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
|||||||
if filters.Model != nil {
|
if filters.Model != nil {
|
||||||
model = *filters.Model
|
model = *filters.Model
|
||||||
}
|
}
|
||||||
var stream any
|
var streamValue any
|
||||||
if filters.Stream != nil {
|
if filters.Stream != nil {
|
||||||
stream = *filters.Stream
|
streamValue = *filters.Stream
|
||||||
|
}
|
||||||
|
var requestTypeName any
|
||||||
|
if filters.RequestType != nil {
|
||||||
|
requestTypeName = service.RequestTypeFromInt16(*filters.RequestType).String()
|
||||||
}
|
}
|
||||||
var billingType any
|
var billingType any
|
||||||
if filters.BillingType != nil {
|
if filters.BillingType != nil {
|
||||||
@@ -481,7 +520,7 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
|||||||
Body: req,
|
Body: req,
|
||||||
}
|
}
|
||||||
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
|
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v request_type=%v stream=%v billing_type=%v tz=%q",
|
||||||
subject.UserID,
|
subject.UserID,
|
||||||
filters.StartTime.Format(time.RFC3339),
|
filters.StartTime.Format(time.RFC3339),
|
||||||
filters.EndTime.Format(time.RFC3339),
|
filters.EndTime.Format(time.RFC3339),
|
||||||
@@ -490,7 +529,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
|||||||
accountID,
|
accountID,
|
||||||
groupID,
|
groupID,
|
||||||
model,
|
model,
|
||||||
stream,
|
requestTypeName,
|
||||||
|
streamValue,
|
||||||
billingType,
|
billingType,
|
||||||
req.Timezone,
|
req.Timezone,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,117 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type adminUsageRepoCapture struct {
|
||||||
|
service.UsageLogRepository
|
||||||
|
listFilters usagestats.UsageLogFilters
|
||||||
|
statsFilters usagestats.UsageLogFilters
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
|
s.listFilters = filters
|
||||||
|
return []service.UsageLog{}, &pagination.PaginationResult{
|
||||||
|
Total: 0,
|
||||||
|
Page: params.Page,
|
||||||
|
PageSize: params.PageSize,
|
||||||
|
Pages: 0,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *adminUsageRepoCapture) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||||
|
s.statsFilters = filters
|
||||||
|
return &usagestats.UsageStats{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAdminUsageRequestTypeTestRouter(repo *adminUsageRepoCapture) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
usageSvc := service.NewUsageService(repo, nil, nil, nil)
|
||||||
|
handler := NewUsageHandler(usageSvc, nil, nil, nil)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/admin/usage", handler.List)
|
||||||
|
router.GET("/admin/usage/stats", handler.Stats)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminUsageListRequestTypePriority(t *testing.T) {
|
||||||
|
repo := &adminUsageRepoCapture{}
|
||||||
|
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=ws_v2&stream=false", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.NotNil(t, repo.listFilters.RequestType)
|
||||||
|
require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
|
||||||
|
require.Nil(t, repo.listFilters.Stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminUsageListInvalidRequestType(t *testing.T) {
|
||||||
|
repo := &adminUsageRepoCapture{}
|
||||||
|
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=bad", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminUsageListInvalidStream(t *testing.T) {
|
||||||
|
repo := &adminUsageRepoCapture{}
|
||||||
|
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/usage?stream=bad", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
|
||||||
|
repo := &adminUsageRepoCapture{}
|
||||||
|
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=stream&stream=bad", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.NotNil(t, repo.statsFilters.RequestType)
|
||||||
|
require.Equal(t, int16(service.RequestTypeStream), *repo.statsFilters.RequestType)
|
||||||
|
require.Nil(t, repo.statsFilters.Stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminUsageStatsInvalidRequestType(t *testing.T) {
|
||||||
|
repo := &adminUsageRepoCapture{}
|
||||||
|
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=oops", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminUsageStatsInvalidStream(t *testing.T) {
|
||||||
|
repo := &adminUsageRepoCapture{}
|
||||||
|
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?stream=oops", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
@@ -34,13 +34,14 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi
|
|||||||
|
|
||||||
// CreateUserRequest represents admin create user request
|
// CreateUserRequest represents admin create user request
|
||||||
type CreateUserRequest struct {
|
type CreateUserRequest struct {
|
||||||
Email string `json:"email" binding:"required,email"`
|
Email string `json:"email" binding:"required,email"`
|
||||||
Password string `json:"password" binding:"required,min=6"`
|
Password string `json:"password" binding:"required,min=6"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Notes string `json:"notes"`
|
Notes string `json:"notes"`
|
||||||
Balance float64 `json:"balance"`
|
Balance float64 `json:"balance"`
|
||||||
Concurrency int `json:"concurrency"`
|
Concurrency int `json:"concurrency"`
|
||||||
AllowedGroups []int64 `json:"allowed_groups"`
|
AllowedGroups []int64 `json:"allowed_groups"`
|
||||||
|
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateUserRequest represents admin update user request
|
// UpdateUserRequest represents admin update user request
|
||||||
@@ -56,7 +57,8 @@ type UpdateUserRequest struct {
|
|||||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||||
// GroupRates 用户专属分组倍率配置
|
// GroupRates 用户专属分组倍率配置
|
||||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||||
GroupRates map[int64]*float64 `json:"group_rates"`
|
GroupRates map[int64]*float64 `json:"group_rates"`
|
||||||
|
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateBalanceRequest represents balance update request
|
// UpdateBalanceRequest represents balance update request
|
||||||
@@ -174,13 +176,14 @@ func (h *UserHandler) Create(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
||||||
Email: req.Email,
|
Email: req.Email,
|
||||||
Password: req.Password,
|
Password: req.Password,
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Notes: req.Notes,
|
Notes: req.Notes,
|
||||||
Balance: req.Balance,
|
Balance: req.Balance,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
AllowedGroups: req.AllowedGroups,
|
AllowedGroups: req.AllowedGroups,
|
||||||
|
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
@@ -207,15 +210,16 @@ func (h *UserHandler) Update(c *gin.Context) {
|
|||||||
|
|
||||||
// 使用指针类型直接传递,nil 表示未提供该字段
|
// 使用指针类型直接传递,nil 表示未提供该字段
|
||||||
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
||||||
Email: req.Email,
|
Email: req.Email,
|
||||||
Password: req.Password,
|
Password: req.Password,
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Notes: req.Notes,
|
Notes: req.Notes,
|
||||||
Balance: req.Balance,
|
Balance: req.Balance,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
AllowedGroups: req.AllowedGroups,
|
AllowedGroups: req.AllowedGroups,
|
||||||
GroupRates: req.GroupRates,
|
GroupRates: req.GroupRates,
|
||||||
|
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
|
|||||||
@@ -113,9 +113,8 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Turnstile 验证 — 始终执行,防止绕过
|
// Turnstile 验证(邮箱验证码注册场景避免重复校验一次性 token)
|
||||||
// TODO: 确认前端在提交邮箱验证码注册时也传递了 turnstile_token
|
if err := h.authService.VerifyTurnstileForRegister(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c), req.VerifyCode); err != nil {
|
||||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,9 +59,11 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &AdminUser{
|
return &AdminUser{
|
||||||
User: *base,
|
User: *base,
|
||||||
Notes: u.Notes,
|
Notes: u.Notes,
|
||||||
GroupRates: u.GroupRates,
|
GroupRates: u.GroupRates,
|
||||||
|
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
|
||||||
|
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,6 +154,7 @@ func groupFromServiceBase(g *service.Group) Group {
|
|||||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||||
FallbackGroupID: g.FallbackGroupID,
|
FallbackGroupID: g.FallbackGroupID,
|
||||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||||
|
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||||
CreatedAt: g.CreatedAt,
|
CreatedAt: g.CreatedAt,
|
||||||
UpdatedAt: g.UpdatedAt,
|
UpdatedAt: g.UpdatedAt,
|
||||||
}
|
}
|
||||||
@@ -385,6 +388,8 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
|
|||||||
|
|
||||||
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||||
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
|
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
|
||||||
|
requestType := l.EffectiveRequestType()
|
||||||
|
stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode)
|
||||||
return UsageLog{
|
return UsageLog{
|
||||||
ID: l.ID,
|
ID: l.ID,
|
||||||
UserID: l.UserID,
|
UserID: l.UserID,
|
||||||
@@ -409,7 +414,9 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
|||||||
ActualCost: l.ActualCost,
|
ActualCost: l.ActualCost,
|
||||||
RateMultiplier: l.RateMultiplier,
|
RateMultiplier: l.RateMultiplier,
|
||||||
BillingType: l.BillingType,
|
BillingType: l.BillingType,
|
||||||
Stream: l.Stream,
|
RequestType: requestType.String(),
|
||||||
|
Stream: stream,
|
||||||
|
OpenAIWSMode: openAIWSMode,
|
||||||
DurationMs: l.DurationMs,
|
DurationMs: l.DurationMs,
|
||||||
FirstTokenMs: l.FirstTokenMs,
|
FirstTokenMs: l.FirstTokenMs,
|
||||||
ImageCount: l.ImageCount,
|
ImageCount: l.ImageCount,
|
||||||
@@ -464,6 +471,7 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
|
|||||||
AccountID: task.Filters.AccountID,
|
AccountID: task.Filters.AccountID,
|
||||||
GroupID: task.Filters.GroupID,
|
GroupID: task.Filters.GroupID,
|
||||||
Model: task.Filters.Model,
|
Model: task.Filters.Model,
|
||||||
|
RequestType: requestTypeStringPtr(task.Filters.RequestType),
|
||||||
Stream: task.Filters.Stream,
|
Stream: task.Filters.Stream,
|
||||||
BillingType: task.Filters.BillingType,
|
BillingType: task.Filters.BillingType,
|
||||||
},
|
},
|
||||||
@@ -479,6 +487,14 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func requestTypeStringPtr(requestType *int16) *string {
|
||||||
|
if requestType == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
value := service.RequestTypeFromInt16(*requestType).String()
|
||||||
|
return &value
|
||||||
|
}
|
||||||
|
|
||||||
func SettingFromService(s *service.Setting) *Setting {
|
func SettingFromService(s *service.Setting) *Setting {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
73
backend/internal/handler/dto/mappers_usage_test.go
Normal file
73
backend/internal/handler/dto/mappers_usage_test.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUsageLogFromService_IncludesOpenAIWSMode(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
wsLog := &service.UsageLog{
|
||||||
|
RequestID: "req_1",
|
||||||
|
Model: "gpt-5.3-codex",
|
||||||
|
OpenAIWSMode: true,
|
||||||
|
}
|
||||||
|
httpLog := &service.UsageLog{
|
||||||
|
RequestID: "resp_1",
|
||||||
|
Model: "gpt-5.3-codex",
|
||||||
|
OpenAIWSMode: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.True(t, UsageLogFromService(wsLog).OpenAIWSMode)
|
||||||
|
require.False(t, UsageLogFromService(httpLog).OpenAIWSMode)
|
||||||
|
require.True(t, UsageLogFromServiceAdmin(wsLog).OpenAIWSMode)
|
||||||
|
require.False(t, UsageLogFromServiceAdmin(httpLog).OpenAIWSMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogFromService_PrefersRequestTypeForLegacyFields(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
log := &service.UsageLog{
|
||||||
|
RequestID: "req_2",
|
||||||
|
Model: "gpt-5.3-codex",
|
||||||
|
RequestType: service.RequestTypeWSV2,
|
||||||
|
Stream: false,
|
||||||
|
OpenAIWSMode: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
userDTO := UsageLogFromService(log)
|
||||||
|
adminDTO := UsageLogFromServiceAdmin(log)
|
||||||
|
|
||||||
|
require.Equal(t, "ws_v2", userDTO.RequestType)
|
||||||
|
require.True(t, userDTO.Stream)
|
||||||
|
require.True(t, userDTO.OpenAIWSMode)
|
||||||
|
require.Equal(t, "ws_v2", adminDTO.RequestType)
|
||||||
|
require.True(t, adminDTO.Stream)
|
||||||
|
require.True(t, adminDTO.OpenAIWSMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageCleanupTaskFromService_RequestTypeMapping(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
requestType := int16(service.RequestTypeStream)
|
||||||
|
task := &service.UsageCleanupTask{
|
||||||
|
ID: 1,
|
||||||
|
Status: service.UsageCleanupStatusPending,
|
||||||
|
Filters: service.UsageCleanupFilters{
|
||||||
|
RequestType: &requestType,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
dtoTask := UsageCleanupTaskFromService(task)
|
||||||
|
require.NotNil(t, dtoTask)
|
||||||
|
require.NotNil(t, dtoTask.Filters.RequestType)
|
||||||
|
require.Equal(t, "stream", *dtoTask.Filters.RequestType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestTypeStringPtrNil(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
require.Nil(t, requestTypeStringPtr(nil))
|
||||||
|
}
|
||||||
@@ -37,6 +37,7 @@ type SystemSettings struct {
|
|||||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||||
|
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||||
|
|
||||||
DefaultConcurrency int `json:"default_concurrency"`
|
DefaultConcurrency int `json:"default_concurrency"`
|
||||||
DefaultBalance float64 `json:"default_balance"`
|
DefaultBalance float64 `json:"default_balance"`
|
||||||
@@ -79,9 +80,48 @@ type PublicSettings struct {
|
|||||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
|
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
|
||||||
|
type SoraS3Settings struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
Bucket string `json:"bucket"`
|
||||||
|
AccessKeyID string `json:"access_key_id"`
|
||||||
|
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||||
|
Prefix string `json:"prefix"`
|
||||||
|
ForcePathStyle bool `json:"force_path_style"`
|
||||||
|
CDNURL string `json:"cdn_url"`
|
||||||
|
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段)
|
||||||
|
type SoraS3Profile struct {
|
||||||
|
ProfileID string `json:"profile_id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
IsActive bool `json:"is_active"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
Bucket string `json:"bucket"`
|
||||||
|
AccessKeyID string `json:"access_key_id"`
|
||||||
|
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||||
|
Prefix string `json:"prefix"`
|
||||||
|
ForcePathStyle bool `json:"force_path_style"`
|
||||||
|
CDNURL string `json:"cdn_url"`
|
||||||
|
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSoraS3ProfilesResponse Sora S3 配置列表响应
|
||||||
|
type ListSoraS3ProfilesResponse struct {
|
||||||
|
ActiveProfileID string `json:"active_profile_id"`
|
||||||
|
Items []SoraS3Profile `json:"items"`
|
||||||
|
}
|
||||||
|
|
||||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||||
type StreamTimeoutSettings struct {
|
type StreamTimeoutSettings struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
|
|||||||
@@ -26,7 +26,9 @@ type AdminUser struct {
|
|||||||
Notes string `json:"notes"`
|
Notes string `json:"notes"`
|
||||||
// GroupRates 用户专属分组倍率配置
|
// GroupRates 用户专属分组倍率配置
|
||||||
// map[groupID]rateMultiplier
|
// map[groupID]rateMultiplier
|
||||||
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
||||||
|
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||||
|
SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
@@ -80,6 +82,9 @@ type Group struct {
|
|||||||
// 无效请求兜底分组
|
// 无效请求兜底分组
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
|
|
||||||
|
// Sora 存储配额
|
||||||
|
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||||
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
@@ -278,10 +283,12 @@ type UsageLog struct {
|
|||||||
ActualCost float64 `json:"actual_cost"`
|
ActualCost float64 `json:"actual_cost"`
|
||||||
RateMultiplier float64 `json:"rate_multiplier"`
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
|
|
||||||
BillingType int8 `json:"billing_type"`
|
BillingType int8 `json:"billing_type"`
|
||||||
Stream bool `json:"stream"`
|
RequestType string `json:"request_type"`
|
||||||
DurationMs *int `json:"duration_ms"`
|
Stream bool `json:"stream"`
|
||||||
FirstTokenMs *int `json:"first_token_ms"`
|
OpenAIWSMode bool `json:"openai_ws_mode"`
|
||||||
|
DurationMs *int `json:"duration_ms"`
|
||||||
|
FirstTokenMs *int `json:"first_token_ms"`
|
||||||
|
|
||||||
// 图片生成字段
|
// 图片生成字段
|
||||||
ImageCount int `json:"image_count"`
|
ImageCount int `json:"image_count"`
|
||||||
@@ -324,6 +331,7 @@ type UsageCleanupFilters struct {
|
|||||||
AccountID *int64 `json:"account_id,omitempty"`
|
AccountID *int64 `json:"account_id,omitempty"`
|
||||||
GroupID *int64 `json:"group_id,omitempty"`
|
GroupID *int64 `json:"group_id,omitempty"`
|
||||||
Model *string `json:"model,omitempty"`
|
Model *string `json:"model,omitempty"`
|
||||||
|
RequestType *string `json:"request_type,omitempty"`
|
||||||
Stream *bool `json:"stream,omitempty"`
|
Stream *bool `json:"stream,omitempty"`
|
||||||
BillingType *int8 `json:"billing_type,omitempty"`
|
BillingType *int8 `json:"billing_type,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,12 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
|
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
|
||||||
@@ -78,8 +79,12 @@ func (s *FailoverState) HandleFailoverError(
|
|||||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||||
if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
|
if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
|
||||||
s.SameAccountRetryCount[accountID]++
|
s.SameAccountRetryCount[accountID]++
|
||||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
logger.FromContext(ctx).Warn("gateway.failover_same_account_retry",
|
||||||
accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries)
|
zap.Int64("account_id", accountID),
|
||||||
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.Int("same_account_retry_count", s.SameAccountRetryCount[accountID]),
|
||||||
|
zap.Int("same_account_retry_max", maxSameAccountRetries),
|
||||||
|
)
|
||||||
if !sleepWithContext(ctx, sameAccountRetryDelay) {
|
if !sleepWithContext(ctx, sameAccountRetryDelay) {
|
||||||
return FailoverCanceled
|
return FailoverCanceled
|
||||||
}
|
}
|
||||||
@@ -101,8 +106,12 @@ func (s *FailoverState) HandleFailoverError(
|
|||||||
|
|
||||||
// 递增切换计数
|
// 递增切换计数
|
||||||
s.SwitchCount++
|
s.SwitchCount++
|
||||||
log.Printf("Account %d: upstream error %d, switching account %d/%d",
|
logger.FromContext(ctx).Warn("gateway.failover_switch_account",
|
||||||
accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches)
|
zap.Int64("account_id", accountID),
|
||||||
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.Int("switch_count", s.SwitchCount),
|
||||||
|
zap.Int("max_switches", s.MaxSwitches),
|
||||||
|
)
|
||||||
|
|
||||||
// Antigravity 平台换号线性递增延时
|
// Antigravity 平台换号线性递增延时
|
||||||
if platform == service.PlatformAntigravity {
|
if platform == service.PlatformAntigravity {
|
||||||
@@ -127,13 +136,18 @@ func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAc
|
|||||||
s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
|
s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
|
||||||
s.SwitchCount <= s.MaxSwitches {
|
s.SwitchCount <= s.MaxSwitches {
|
||||||
|
|
||||||
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)",
|
logger.FromContext(ctx).Warn("gateway.failover_single_account_backoff",
|
||||||
singleAccountBackoffDelay, s.SwitchCount)
|
zap.Duration("backoff_delay", singleAccountBackoffDelay),
|
||||||
|
zap.Int("switch_count", s.SwitchCount),
|
||||||
|
zap.Int("max_switches", s.MaxSwitches),
|
||||||
|
)
|
||||||
if !sleepWithContext(ctx, singleAccountBackoffDelay) {
|
if !sleepWithContext(ctx, singleAccountBackoffDelay) {
|
||||||
return FailoverCanceled
|
return FailoverCanceled
|
||||||
}
|
}
|
||||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d",
|
logger.FromContext(ctx).Warn("gateway.failover_single_account_retry",
|
||||||
s.SwitchCount, s.MaxSwitches)
|
zap.Int("switch_count", s.SwitchCount),
|
||||||
|
zap.Int("max_switches", s.MaxSwitches),
|
||||||
|
)
|
||||||
s.FailedAccountIDs = make(map[int64]struct{})
|
s.FailedAccountIDs = make(map[int64]struct{})
|
||||||
return FailoverContinue
|
return FailoverContinue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,9 +6,10 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
@@ -17,6 +18,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
@@ -27,6 +29,10 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const gatewayCompatibilityMetricsLogInterval = 1024
|
||||||
|
|
||||||
|
var gatewayCompatibilityMetricsLogCounter atomic.Uint64
|
||||||
|
|
||||||
// GatewayHandler handles API gateway requests
|
// GatewayHandler handles API gateway requests
|
||||||
type GatewayHandler struct {
|
type GatewayHandler struct {
|
||||||
gatewayService *service.GatewayService
|
gatewayService *service.GatewayService
|
||||||
@@ -109,9 +115,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
zap.Int64("api_key_id", apiKey.ID),
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
zap.Any("group_id", apiKey.GroupID),
|
zap.Any("group_id", apiKey.GroupID),
|
||||||
)
|
)
|
||||||
|
defer h.maybeLogCompatibilityFallbackMetrics(reqLog)
|
||||||
|
|
||||||
// 读取请求体
|
// 读取请求体
|
||||||
body, err := io.ReadAll(c.Request.Body)
|
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||||
@@ -140,16 +147,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
||||||
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
||||||
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
|
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
|
ctx := service.WithIsMaxTokensOneHaikuRequest(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
// 检查是否为 Claude Code 客户端,设置到 context 中(复用已解析请求,避免二次反序列化)。
|
||||||
SetClaudeCodeClientContext(c, body)
|
SetClaudeCodeClientContext(c, body, parsedReq)
|
||||||
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
|
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
|
||||||
|
|
||||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
|
||||||
|
|
||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
|
|
||||||
@@ -247,8 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if apiKey.GroupID != nil {
|
if apiKey.GroupID != nil {
|
||||||
prefetchedGroupID = *apiKey.GroupID
|
prefetchedGroupID = *apiKey.GroupID
|
||||||
}
|
}
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
|
||||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
|
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -261,7 +267,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||||
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
|
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -275,7 +281,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||||
switch action {
|
switch action {
|
||||||
case FailoverContinue:
|
case FailoverContinue:
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
continue
|
continue
|
||||||
case FailoverCanceled:
|
case FailoverCanceled:
|
||||||
@@ -364,7 +370,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||||
}
|
}
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||||
@@ -439,7 +445,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||||
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) {
|
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) {
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -458,7 +464,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||||
switch action {
|
switch action {
|
||||||
case FailoverContinue:
|
case FailoverContinue:
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
continue
|
continue
|
||||||
case FailoverCanceled:
|
case FailoverCanceled:
|
||||||
@@ -547,7 +553,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||||
}
|
}
|
||||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||||
@@ -956,20 +962,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
|||||||
// Stream already started, send error as SSE event then close
|
// Stream already started, send error as SSE event then close
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if ok {
|
if ok {
|
||||||
// Send error event in SSE format with proper JSON marshaling
|
// SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。
|
||||||
errorData := map[string]any{
|
errorEvent := `data: {"type":"error","error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
|
||||||
"type": "error",
|
|
||||||
"error": map[string]string{
|
|
||||||
"type": errType,
|
|
||||||
"message": message,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
jsonBytes, err := json.Marshal(errorData)
|
|
||||||
if err != nil {
|
|
||||||
_ = c.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes))
|
|
||||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
}
|
}
|
||||||
@@ -1024,9 +1018,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
zap.Int64("api_key_id", apiKey.ID),
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
zap.Any("group_id", apiKey.GroupID),
|
zap.Any("group_id", apiKey.GroupID),
|
||||||
)
|
)
|
||||||
|
defer h.maybeLogCompatibilityFallbackMetrics(reqLog)
|
||||||
|
|
||||||
// 读取请求体
|
// 读取请求体
|
||||||
body, err := io.ReadAll(c.Request.Body)
|
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||||
@@ -1041,9 +1036,6 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
|
||||||
SetClaudeCodeClientContext(c, body)
|
|
||||||
|
|
||||||
setOpsRequestContext(c, "", false, body)
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
|
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||||
@@ -1051,9 +1043,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// count_tokens 走 messages 严格校验时,复用已解析请求,避免二次反序列化。
|
||||||
|
SetClaudeCodeClientContext(c, body, parsedReq)
|
||||||
reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream))
|
reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream))
|
||||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
|
||||||
|
|
||||||
// 验证 model 必填
|
// 验证 model 必填
|
||||||
if parsedReq.Model == "" {
|
if parsedReq.Model == "" {
|
||||||
@@ -1217,24 +1211,8 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
|
|||||||
textDeltas = []string{"New", " Conversation"}
|
textDeltas = []string{"New", " Conversation"}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build message_start event with proper JSON marshaling
|
// Build message_start event with fixed schema.
|
||||||
messageStart := map[string]any{
|
messageStartJSON := `{"type":"message_start","message":{"id":` + strconv.Quote(msgID) + `,"type":"message","role":"assistant","model":` + strconv.Quote(model) + `,"content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0}}}`
|
||||||
"type": "message_start",
|
|
||||||
"message": map[string]any{
|
|
||||||
"id": msgID,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": model,
|
|
||||||
"content": []any{},
|
|
||||||
"stop_reason": nil,
|
|
||||||
"stop_sequence": nil,
|
|
||||||
"usage": map[string]int{
|
|
||||||
"input_tokens": 10,
|
|
||||||
"output_tokens": 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
messageStartJSON, _ := json.Marshal(messageStart)
|
|
||||||
|
|
||||||
// Build events
|
// Build events
|
||||||
events := []string{
|
events := []string{
|
||||||
@@ -1244,31 +1222,12 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
|
|||||||
|
|
||||||
// Add text deltas
|
// Add text deltas
|
||||||
for _, text := range textDeltas {
|
for _, text := range textDeltas {
|
||||||
delta := map[string]any{
|
deltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + strconv.Quote(text) + `}}`
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": 0,
|
|
||||||
"delta": map[string]string{
|
|
||||||
"type": "text_delta",
|
|
||||||
"text": text,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
deltaJSON, _ := json.Marshal(delta)
|
|
||||||
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
|
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add final events
|
// Add final events
|
||||||
messageDelta := map[string]any{
|
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":10,"output_tokens":` + strconv.Itoa(outputTokens) + `}}`
|
||||||
"type": "message_delta",
|
|
||||||
"delta": map[string]any{
|
|
||||||
"stop_reason": "end_turn",
|
|
||||||
"stop_sequence": nil,
|
|
||||||
},
|
|
||||||
"usage": map[string]int{
|
|
||||||
"input_tokens": 10,
|
|
||||||
"output_tokens": outputTokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
messageDeltaJSON, _ := json.Marshal(messageDelta)
|
|
||||||
|
|
||||||
events = append(events,
|
events = append(events,
|
||||||
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
|
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
|
||||||
@@ -1366,6 +1325,30 @@ func billingErrorDetails(err error) (status int, code, message string) {
|
|||||||
return http.StatusForbidden, "billing_error", msg
|
return http.StatusForbidden, "billing_error", msg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *GatewayHandler) metadataBridgeEnabled() bool {
|
||||||
|
if h == nil || h.cfg == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return h.cfg.Gateway.OpenAIWS.MetadataBridgeEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger) {
|
||||||
|
if reqLog == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if gatewayCompatibilityMetricsLogCounter.Add(1)%gatewayCompatibilityMetricsLogInterval != 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
metrics := service.SnapshotOpenAICompatibilityFallbackMetrics()
|
||||||
|
reqLog.Info("gateway.compatibility_fallback_metrics",
|
||||||
|
zap.Int64("session_hash_legacy_read_fallback_total", metrics.SessionHashLegacyReadFallbackTotal),
|
||||||
|
zap.Int64("session_hash_legacy_read_fallback_hit", metrics.SessionHashLegacyReadFallbackHit),
|
||||||
|
zap.Int64("session_hash_legacy_dual_write_total", metrics.SessionHashLegacyDualWriteTotal),
|
||||||
|
zap.Float64("session_hash_legacy_read_hit_rate", metrics.SessionHashLegacyReadHitRate),
|
||||||
|
zap.Int64("metadata_legacy_fallback_total", metrics.MetadataLegacyFallbackTotal),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||||
if task == nil {
|
if task == nil {
|
||||||
return
|
return
|
||||||
@@ -1377,5 +1360,13 @@ func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
|||||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
defer func() {
|
||||||
|
if recovered := recover(); recovered != nil {
|
||||||
|
logger.L().With(
|
||||||
|
zap.String("component", "handler.gateway.messages"),
|
||||||
|
zap.Any("panic", recovered),
|
||||||
|
).Error("gateway.usage_record_task_panic_recovered")
|
||||||
|
}
|
||||||
|
}()
|
||||||
task(ctx)
|
task(ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -119,6 +119,13 @@ func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.A
|
|||||||
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||||
return map[int64]*service.UserLoadInfo{}, nil
|
return map[int64]*service.UserLoadInfo{}, nil
|
||||||
}
|
}
|
||||||
|
func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||||
|
result := make(map[int64]int, len(accountIDs))
|
||||||
|
for _, id := range accountIDs {
|
||||||
|
result[id] = 0
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||||
|
|
||||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||||
|
|||||||
@@ -18,12 +18,17 @@ import (
|
|||||||
// claudeCodeValidator is a singleton validator for Claude Code client detection
|
// claudeCodeValidator is a singleton validator for Claude Code client detection
|
||||||
var claudeCodeValidator = service.NewClaudeCodeValidator()
|
var claudeCodeValidator = service.NewClaudeCodeValidator()
|
||||||
|
|
||||||
|
const claudeCodeParsedRequestContextKey = "claude_code_parsed_request"
|
||||||
|
|
||||||
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
|
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
|
||||||
// 返回更新后的 context
|
// 返回更新后的 context
|
||||||
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.ParsedRequest) {
|
||||||
if c == nil || c.Request == nil {
|
if c == nil || c.Request == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if parsedReq != nil {
|
||||||
|
c.Set(claudeCodeParsedRequestContextKey, parsedReq)
|
||||||
|
}
|
||||||
// Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。
|
// Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。
|
||||||
if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) {
|
if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) {
|
||||||
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
|
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
|
||||||
@@ -37,8 +42,11 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
|||||||
isClaudeCode = true
|
isClaudeCode = true
|
||||||
} else {
|
} else {
|
||||||
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
|
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
|
||||||
var bodyMap map[string]any
|
bodyMap := claudeCodeBodyMapFromParsedRequest(parsedReq)
|
||||||
if len(body) > 0 {
|
if bodyMap == nil {
|
||||||
|
bodyMap = claudeCodeBodyMapFromContextCache(c)
|
||||||
|
}
|
||||||
|
if bodyMap == nil && len(body) > 0 {
|
||||||
_ = json.Unmarshal(body, &bodyMap)
|
_ = json.Unmarshal(body, &bodyMap)
|
||||||
}
|
}
|
||||||
isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
|
isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
|
||||||
@@ -49,6 +57,42 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
|||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func claudeCodeBodyMapFromParsedRequest(parsedReq *service.ParsedRequest) map[string]any {
|
||||||
|
if parsedReq == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bodyMap := map[string]any{
|
||||||
|
"model": parsedReq.Model,
|
||||||
|
}
|
||||||
|
if parsedReq.System != nil || parsedReq.HasSystem {
|
||||||
|
bodyMap["system"] = parsedReq.System
|
||||||
|
}
|
||||||
|
if parsedReq.MetadataUserID != "" {
|
||||||
|
bodyMap["metadata"] = map[string]any{"user_id": parsedReq.MetadataUserID}
|
||||||
|
}
|
||||||
|
return bodyMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func claudeCodeBodyMapFromContextCache(c *gin.Context) map[string]any {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cached, ok := c.Get(service.OpenAIParsedRequestBodyKey); ok {
|
||||||
|
if bodyMap, ok := cached.(map[string]any); ok {
|
||||||
|
return bodyMap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cached, ok := c.Get(claudeCodeParsedRequestContextKey); ok {
|
||||||
|
switch v := cached.(type) {
|
||||||
|
case *service.ParsedRequest:
|
||||||
|
return claudeCodeBodyMapFromParsedRequest(v)
|
||||||
|
case service.ParsedRequest:
|
||||||
|
return claudeCodeBodyMapFromParsedRequest(&v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// 并发槽位等待相关常量
|
// 并发槽位等待相关常量
|
||||||
//
|
//
|
||||||
// 性能优化说明:
|
// 性能优化说明:
|
||||||
|
|||||||
@@ -33,6 +33,14 @@ func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accoun
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||||
|
result := make(map[int64]int, len(accountIDs))
|
||||||
|
for _, accountID := range accountIDs {
|
||||||
|
result[accountID] = 0
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,6 +49,14 @@ func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context,
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||||
|
out := make(map[int64]int, len(accountIDs))
|
||||||
|
for _, accountID := range accountIDs {
|
||||||
|
out[accountID] = 0
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
@@ -133,7 +141,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
|||||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||||
c.Request.Header.Set("User-Agent", "curl/8.6.0")
|
c.Request.Header.Set("User-Agent", "curl/8.6.0")
|
||||||
|
|
||||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
|
||||||
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -141,7 +149,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
|||||||
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
|
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
|
||||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||||
|
|
||||||
SetClaudeCodeClientContext(c, nil)
|
SetClaudeCodeClientContext(c, nil, nil)
|
||||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -152,7 +160,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
|||||||
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||||
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||||
|
|
||||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
|
||||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -160,11 +168,51 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
|||||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||||
// 缺少严格校验所需 header + body 字段
|
// 缺少严格校验所需 header + body 字段
|
||||||
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`))
|
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`), nil)
|
||||||
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing.T) {
|
||||||
|
t.Run("reuse parsed request without body unmarshal", func(t *testing.T) {
|
||||||
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||||
|
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||||
|
c.Request.Header.Set("X-App", "claude-code")
|
||||||
|
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||||
|
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||||
|
|
||||||
|
parsedReq := &service.ParsedRequest{
|
||||||
|
Model: "claude-3-5-sonnet-20241022",
|
||||||
|
System: []any{
|
||||||
|
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||||
|
},
|
||||||
|
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
|
||||||
|
}
|
||||||
|
|
||||||
|
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
||||||
|
SetClaudeCodeClientContext(c, []byte(`{invalid`), parsedReq)
|
||||||
|
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reuse context cache without body unmarshal", func(t *testing.T) {
|
||||||
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||||
|
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||||
|
c.Request.Header.Set("X-App", "claude-code")
|
||||||
|
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||||
|
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||||
|
c.Set(service.OpenAIParsedRequestBodyKey, map[string]any{
|
||||||
|
"model": "claude-3-5-sonnet-20241022",
|
||||||
|
"system": []any{
|
||||||
|
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||||
|
},
|
||||||
|
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
|
||||||
|
})
|
||||||
|
|
||||||
|
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
||||||
|
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
|
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
|
||||||
cache := &helperConcurrencyCacheStub{
|
cache := &helperConcurrencyCacheStub{
|
||||||
accountSeq: []bool{false, true},
|
accountSeq: []bool{false, true},
|
||||||
|
|||||||
@@ -7,16 +7,15 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||||
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
@@ -168,7 +167,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
stream := action == "streamGenerateContent"
|
stream := action == "streamGenerateContent"
|
||||||
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
|
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
|
||||||
|
|
||||||
body, err := io.ReadAll(c.Request.Body)
|
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||||
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
|
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
|
||||||
@@ -268,8 +267,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
if apiKey.GroupID != nil {
|
if apiKey.GroupID != nil {
|
||||||
prefetchedGroupID = *apiKey.GroupID
|
prefetchedGroupID = *apiKey.GroupID
|
||||||
}
|
}
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
|
||||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
|
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -349,7 +347,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||||
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
|
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -363,7 +361,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||||
switch action {
|
switch action {
|
||||||
case FailoverContinue:
|
case FailoverContinue:
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
continue
|
continue
|
||||||
case FailoverCanceled:
|
case FailoverCanceled:
|
||||||
@@ -456,7 +454,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||||
}
|
}
|
||||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ type AdminHandlers struct {
|
|||||||
Group *admin.GroupHandler
|
Group *admin.GroupHandler
|
||||||
Account *admin.AccountHandler
|
Account *admin.AccountHandler
|
||||||
Announcement *admin.AnnouncementHandler
|
Announcement *admin.AnnouncementHandler
|
||||||
|
DataManagement *admin.DataManagementHandler
|
||||||
OAuth *admin.OAuthHandler
|
OAuth *admin.OAuthHandler
|
||||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||||
GeminiOAuth *admin.GeminiOAuthHandler
|
GeminiOAuth *admin.GeminiOAuthHandler
|
||||||
@@ -40,6 +41,7 @@ type Handlers struct {
|
|||||||
Gateway *GatewayHandler
|
Gateway *GatewayHandler
|
||||||
OpenAIGateway *OpenAIGatewayHandler
|
OpenAIGateway *OpenAIGatewayHandler
|
||||||
SoraGateway *SoraGatewayHandler
|
SoraGateway *SoraGatewayHandler
|
||||||
|
SoraClient *SoraClientHandler
|
||||||
Setting *SettingHandler
|
Setting *SettingHandler
|
||||||
Totp *TotpHandler
|
Totp *TotpHandler
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,17 +5,20 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime/debug"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
coderws "github.com/coder/websocket"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@@ -64,6 +67,11 @@ func NewOpenAIGatewayHandler(
|
|||||||
// Responses handles OpenAI Responses API endpoint
|
// Responses handles OpenAI Responses API endpoint
|
||||||
// POST /openai/v1/responses
|
// POST /openai/v1/responses
|
||||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||||
|
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
|
||||||
|
streamStarted := false
|
||||||
|
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||||
|
setOpenAIClientTransportHTTP(c)
|
||||||
|
|
||||||
requestStart := time.Now()
|
requestStart := time.Now()
|
||||||
|
|
||||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||||
@@ -85,9 +93,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
zap.Int64("api_key_id", apiKey.ID),
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
zap.Any("group_id", apiKey.GroupID),
|
zap.Any("group_id", apiKey.GroupID),
|
||||||
)
|
)
|
||||||
|
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Read request body
|
// Read request body
|
||||||
body, err := io.ReadAll(c.Request.Body)
|
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||||
@@ -125,43 +136,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
reqStream := streamResult.Bool()
|
reqStream := streamResult.Bool()
|
||||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||||
|
previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String())
|
||||||
|
if previousResponseID != "" {
|
||||||
|
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
|
||||||
|
reqLog = reqLog.With(
|
||||||
|
zap.Bool("has_previous_response_id", true),
|
||||||
|
zap.String("previous_response_id_kind", previousResponseIDKind),
|
||||||
|
zap.Int("previous_response_id_len", len(previousResponseID)),
|
||||||
|
)
|
||||||
|
if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
|
||||||
|
reqLog.Warn("openai.request_validation_failed",
|
||||||
|
zap.String("reason", "previous_response_id_looks_like_message_id"),
|
||||||
|
)
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
|
|
||||||
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||||
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
|
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
|
||||||
// 或带 id 且与 call_id 匹配的 item_reference。
|
return
|
||||||
// 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal
|
|
||||||
if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
|
||||||
var reqBody map[string]any
|
|
||||||
if err := json.Unmarshal(body, &reqBody); err == nil {
|
|
||||||
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
|
|
||||||
if service.HasFunctionCallOutput(reqBody) {
|
|
||||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
|
||||||
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
|
||||||
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
|
||||||
reqLog.Warn("openai.request_validation_failed",
|
|
||||||
zap.String("reason", "function_call_output_missing_call_id"),
|
|
||||||
)
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
|
||||||
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
|
||||||
reqLog.Warn("openai.request_validation_failed",
|
|
||||||
zap.String("reason", "function_call_output_missing_item_reference"),
|
|
||||||
)
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track if we've started streaming (for error handling)
|
|
||||||
streamStarted := false
|
|
||||||
|
|
||||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||||
if h.errorPassthroughService != nil {
|
if h.errorPassthroughService != nil {
|
||||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||||
@@ -173,51 +171,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||||
routingStart := time.Now()
|
routingStart := time.Now()
|
||||||
|
|
||||||
// 0. 先尝试直接抢占用户槽位(快速路径)
|
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
|
if !acquired {
|
||||||
if err != nil {
|
|
||||||
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
|
|
||||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
waitCounted := false
|
|
||||||
if !userAcquired {
|
|
||||||
// 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。
|
|
||||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
|
||||||
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
|
||||||
if waitErr != nil {
|
|
||||||
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
|
|
||||||
// 按现有降级语义:等待计数异常时放行后续抢槽流程
|
|
||||||
} else if !canWait {
|
|
||||||
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
|
||||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if waitErr == nil && canWait {
|
|
||||||
waitCounted = true
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if waitCounted {
|
|
||||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
|
||||||
if err != nil {
|
|
||||||
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
|
|
||||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 用户槽位已获取:退出等待队列计数。
|
|
||||||
if waitCounted {
|
|
||||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
|
||||||
waitCounted = false
|
|
||||||
}
|
|
||||||
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
|
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
|
||||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
|
||||||
if userReleaseFunc != nil {
|
if userReleaseFunc != nil {
|
||||||
defer userReleaseFunc()
|
defer userReleaseFunc()
|
||||||
}
|
}
|
||||||
@@ -241,7 +199,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
for {
|
for {
|
||||||
// Select account supporting the requested model
|
// Select account supporting the requested model
|
||||||
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||||
|
c.Request.Context(),
|
||||||
|
apiKey.GroupID,
|
||||||
|
previousResponseID,
|
||||||
|
sessionHash,
|
||||||
|
reqModel,
|
||||||
|
failedAccountIDs,
|
||||||
|
service.OpenAIUpstreamTransportAny,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
reqLog.Warn("openai.account_select_failed",
|
reqLog.Warn("openai.account_select_failed",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
@@ -258,80 +224,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if selection == nil || selection.Account == nil {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if previousResponseID != "" && selection != nil && selection.Account != nil {
|
||||||
|
reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID))
|
||||||
|
}
|
||||||
|
reqLog.Debug("openai.account_schedule_decision",
|
||||||
|
zap.String("layer", scheduleDecision.Layer),
|
||||||
|
zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit),
|
||||||
|
zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
|
||||||
|
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||||
|
zap.Int("top_k", scheduleDecision.TopK),
|
||||||
|
zap.Int64("latency_ms", scheduleDecision.LatencyMs),
|
||||||
|
zap.Float64("load_skew", scheduleDecision.LoadSkew),
|
||||||
|
)
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||||
|
|
||||||
// 3. Acquire account concurrency slot
|
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||||
accountReleaseFunc := selection.ReleaseFunc
|
if !acquired {
|
||||||
if !selection.Acquired {
|
return
|
||||||
if selection.WaitPlan == nil {
|
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 先快速尝试一次账号槽位,命中则跳过等待计数写入。
|
|
||||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
|
||||||
c.Request.Context(),
|
|
||||||
account.ID,
|
|
||||||
selection.WaitPlan.MaxConcurrency,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if fastAcquired {
|
|
||||||
accountReleaseFunc = fastReleaseFunc
|
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
|
||||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
accountWaitCounted := false
|
|
||||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
|
||||||
if err != nil {
|
|
||||||
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
|
||||||
} else if !canWait {
|
|
||||||
reqLog.Info("openai.account_wait_queue_full",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
|
||||||
)
|
|
||||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err == nil && canWait {
|
|
||||||
accountWaitCounted = true
|
|
||||||
}
|
|
||||||
releaseWait := func() {
|
|
||||||
if accountWaitCounted {
|
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
|
||||||
accountWaitCounted = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
|
||||||
c,
|
|
||||||
account.ID,
|
|
||||||
selection.WaitPlan.MaxConcurrency,
|
|
||||||
selection.WaitPlan.Timeout,
|
|
||||||
reqStream,
|
|
||||||
&streamStarted,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
|
||||||
releaseWait()
|
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Slot acquired: no longer waiting in queue.
|
|
||||||
releaseWait()
|
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
|
||||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
|
||||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
|
||||||
|
|
||||||
// Forward request
|
// Forward request
|
||||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||||
@@ -353,6 +269,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||||
|
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
lastFailoverErr = failoverErr
|
lastFailoverErr = failoverErr
|
||||||
if switchCount >= maxAccountSwitches {
|
if switchCount >= maxAccountSwitches {
|
||||||
@@ -368,14 +286,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
reqLog.Error("openai.forward_failed",
|
fields := []zap.Field{
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
)
|
}
|
||||||
|
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||||
|
reqLog.Warn("openai.forward_failed", fields...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqLog.Error("openai.forward_failed", fields...)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if result != nil {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||||
|
} else {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||||
|
}
|
||||||
|
|
||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
@@ -411,6 +340,525 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
|
||||||
|
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var reqBody map[string]any
|
||||||
|
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||||
|
// 保持原有容错语义:解析失败时跳过预校验,沿用后续上游校验结果。
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
|
||||||
|
validation := service.ValidateFunctionCallOutputContext(reqBody)
|
||||||
|
if !validation.HasFunctionCallOutput {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||||
|
if strings.TrimSpace(previousResponseID) != "" || validation.HasToolCallContext {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if validation.HasFunctionCallOutputMissingCallID {
|
||||||
|
reqLog.Warn("openai.request_validation_failed",
|
||||||
|
zap.String("reason", "function_call_output_missing_call_id"),
|
||||||
|
)
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if validation.HasItemReferenceForAllCallIDs {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
reqLog.Warn("openai.request_validation_failed",
|
||||||
|
zap.String("reason", "function_call_output_missing_item_reference"),
|
||||||
|
)
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIGatewayHandler) acquireResponsesUserSlot(
|
||||||
|
c *gin.Context,
|
||||||
|
userID int64,
|
||||||
|
userConcurrency int,
|
||||||
|
reqStream bool,
|
||||||
|
streamStarted *bool,
|
||||||
|
reqLog *zap.Logger,
|
||||||
|
) (func(), bool) {
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
|
||||||
|
h.handleConcurrencyError(c, err, "user", *streamStarted)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if userAcquired {
|
||||||
|
return wrapReleaseOnDone(ctx, userReleaseFunc), true
|
||||||
|
}
|
||||||
|
|
||||||
|
maxWait := service.CalculateMaxWait(userConcurrency)
|
||||||
|
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait)
|
||||||
|
if waitErr != nil {
|
||||||
|
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
|
||||||
|
// 按现有降级语义:等待计数异常时放行后续抢槽流程
|
||||||
|
} else if !canWait {
|
||||||
|
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||||
|
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
waitCounted := waitErr == nil && canWait
|
||||||
|
defer func() {
|
||||||
|
if waitCounted {
|
||||||
|
h.concurrencyHelper.DecrementWaitCount(ctx, userID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
|
||||||
|
h.handleConcurrencyError(c, err, "user", *streamStarted)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 槽位获取成功后,立刻退出等待计数。
|
||||||
|
if waitCounted {
|
||||||
|
h.concurrencyHelper.DecrementWaitCount(ctx, userID)
|
||||||
|
waitCounted = false
|
||||||
|
}
|
||||||
|
return wrapReleaseOnDone(ctx, userReleaseFunc), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
|
||||||
|
c *gin.Context,
|
||||||
|
groupID *int64,
|
||||||
|
sessionHash string,
|
||||||
|
selection *service.AccountSelectionResult,
|
||||||
|
reqStream bool,
|
||||||
|
streamStarted *bool,
|
||||||
|
reqLog *zap.Logger,
|
||||||
|
) (func(), bool) {
|
||||||
|
if selection == nil || selection.Account == nil {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
account := selection.Account
|
||||||
|
if selection.Acquired {
|
||||||
|
return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
|
||||||
|
}
|
||||||
|
if selection.WaitPlan == nil {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||||
|
ctx,
|
||||||
|
account.ID,
|
||||||
|
selection.WaitPlan.MaxConcurrency,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
|
h.handleConcurrencyError(c, err, "account", *streamStarted)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if fastAcquired {
|
||||||
|
if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
|
||||||
|
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
|
}
|
||||||
|
return wrapReleaseOnDone(ctx, fastReleaseFunc), true
|
||||||
|
}
|
||||||
|
|
||||||
|
canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, account.ID, selection.WaitPlan.MaxWaiting)
|
||||||
|
if waitErr != nil {
|
||||||
|
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(waitErr))
|
||||||
|
} else if !canWait {
|
||||||
|
reqLog.Info("openai.account_wait_queue_full",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||||
|
)
|
||||||
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
accountWaitCounted := waitErr == nil && canWait
|
||||||
|
releaseWait := func() {
|
||||||
|
if accountWaitCounted {
|
||||||
|
h.concurrencyHelper.DecrementAccountWaitCount(ctx, account.ID)
|
||||||
|
accountWaitCounted = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer releaseWait()
|
||||||
|
|
||||||
|
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||||
|
c,
|
||||||
|
account.ID,
|
||||||
|
selection.WaitPlan.MaxConcurrency,
|
||||||
|
selection.WaitPlan.Timeout,
|
||||||
|
reqStream,
|
||||||
|
streamStarted,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
|
h.handleConcurrencyError(c, err, "account", *streamStarted)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Slot acquired: no longer waiting in queue.
|
||||||
|
releaseWait()
|
||||||
|
if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
|
||||||
|
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
|
}
|
||||||
|
return wrapReleaseOnDone(ctx, accountReleaseFunc), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponsesWebSocket handles OpenAI Responses API WebSocket ingress endpoint
|
||||||
|
// GET /openai/v1/responses (Upgrade: websocket)
|
||||||
|
func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||||
|
if !isOpenAIWSUpgradeRequest(c.Request) {
|
||||||
|
h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
setOpenAIClientTransportWS(c)
|
||||||
|
|
||||||
|
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reqLog := requestLogger(
|
||||||
|
c,
|
||||||
|
"handler.openai_gateway.responses_ws",
|
||||||
|
zap.Int64("user_id", subject.UserID),
|
||||||
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
|
zap.Any("group_id", apiKey.GroupID),
|
||||||
|
zap.Bool("openai_ws_mode", true),
|
||||||
|
)
|
||||||
|
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqLog.Info("openai.websocket_ingress_started")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
userAgent := strings.TrimSpace(c.GetHeader("User-Agent"))
|
||||||
|
|
||||||
|
wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{
|
||||||
|
CompressionMode: coderws.CompressionContextTakeover,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai.websocket_accept_failed",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("client_ip", clientIP),
|
||||||
|
zap.String("request_user_agent", userAgent),
|
||||||
|
zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))),
|
||||||
|
zap.String("connection_header", strings.TrimSpace(c.GetHeader("Connection"))),
|
||||||
|
zap.String("sec_websocket_version", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))),
|
||||||
|
zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = wsConn.CloseNow()
|
||||||
|
}()
|
||||||
|
wsConn.SetReadLimit(16 * 1024 * 1024)
|
||||||
|
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
readCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
msgType, firstMessage, err := wsConn.Read(readCtx)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||||
|
reqLog.Warn("openai.websocket_read_first_message_failed",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("client_ip", clientIP),
|
||||||
|
zap.String("close_status", closeStatus),
|
||||||
|
zap.String("close_reason", closeReason),
|
||||||
|
zap.Duration("read_timeout", 30*time.Second),
|
||||||
|
)
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !gjson.ValidBytes(firstMessage) {
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String())
|
||||||
|
if reqModel == "" {
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String())
|
||||||
|
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
|
||||||
|
if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqLog = reqLog.With(
|
||||||
|
zap.Bool("ws_ingress", true),
|
||||||
|
zap.String("model", reqModel),
|
||||||
|
zap.Bool("has_previous_response_id", previousResponseID != ""),
|
||||||
|
zap.String("previous_response_id_kind", previousResponseIDKind),
|
||||||
|
)
|
||||||
|
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||||
|
|
||||||
|
var currentUserRelease func()
|
||||||
|
var currentAccountRelease func()
|
||||||
|
releaseTurnSlots := func() {
|
||||||
|
if currentAccountRelease != nil {
|
||||||
|
currentAccountRelease()
|
||||||
|
currentAccountRelease = nil
|
||||||
|
}
|
||||||
|
if currentUserRelease != nil {
|
||||||
|
currentUserRelease()
|
||||||
|
currentUserRelease = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 必须尽早注册,确保任何 early return 都能释放已获取的并发槽位。
|
||||||
|
defer releaseTurnSlots()
|
||||||
|
|
||||||
|
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err))
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !userAcquired {
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||||
|
|
||||||
|
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||||
|
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
|
reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionHash := h.gatewayService.GenerateSessionHashWithFallback(
|
||||||
|
c,
|
||||||
|
firstMessage,
|
||||||
|
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
|
||||||
|
)
|
||||||
|
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||||
|
ctx,
|
||||||
|
apiKey.GroupID,
|
||||||
|
previousResponseID,
|
||||||
|
sessionHash,
|
||||||
|
reqModel,
|
||||||
|
nil,
|
||||||
|
service.OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if selection == nil || selection.Account == nil {
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
account := selection.Account
|
||||||
|
accountMaxConcurrency := account.Concurrency
|
||||||
|
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
|
||||||
|
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
|
||||||
|
}
|
||||||
|
accountReleaseFunc := selection.ReleaseFunc
|
||||||
|
if !selection.Acquired {
|
||||||
|
if selection.WaitPlan == nil {
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||||
|
ctx,
|
||||||
|
account.ID,
|
||||||
|
selection.WaitPlan.MaxConcurrency,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !fastAcquired {
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
accountReleaseFunc = fastReleaseFunc
|
||||||
|
}
|
||||||
|
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||||
|
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||||
|
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reqLog.Debug("openai.websocket_account_selected",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.String("account_name", account.Name),
|
||||||
|
zap.String("schedule_layer", scheduleDecision.Layer),
|
||||||
|
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||||
|
)
|
||||||
|
|
||||||
|
hooks := &service.OpenAIWSIngressHooks{
|
||||||
|
BeforeTurn: func(turn int) error {
|
||||||
|
if turn == 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
|
||||||
|
releaseTurnSlots()
|
||||||
|
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
|
||||||
|
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
|
||||||
|
}
|
||||||
|
if !userAcquired {
|
||||||
|
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
|
||||||
|
}
|
||||||
|
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
|
||||||
|
if err != nil {
|
||||||
|
if userReleaseFunc != nil {
|
||||||
|
userReleaseFunc()
|
||||||
|
}
|
||||||
|
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
|
||||||
|
}
|
||||||
|
if !accountAcquired {
|
||||||
|
if userReleaseFunc != nil {
|
||||||
|
userReleaseFunc()
|
||||||
|
}
|
||||||
|
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
|
||||||
|
}
|
||||||
|
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||||
|
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
|
||||||
|
releaseTurnSlots()
|
||||||
|
if turnErr != nil || result == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||||
|
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||||
|
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
APIKey: apiKey,
|
||||||
|
User: apiKey.User,
|
||||||
|
Account: account,
|
||||||
|
Subscription: subscription,
|
||||||
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
|
}); err != nil {
|
||||||
|
reqLog.Error("openai.websocket_record_usage_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.String("request_id", result.RequestID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||||
|
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||||
|
reqLog.Warn("openai.websocket_proxy_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("close_status", closeStatus),
|
||||||
|
zap.String("close_reason", closeReason),
|
||||||
|
)
|
||||||
|
var closeErr *service.OpenAIWSClientCloseError
|
||||||
|
if errors.As(err, &closeErr) {
|
||||||
|
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
|
||||||
|
recovered := recover()
|
||||||
|
if recovered == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
started := false
|
||||||
|
if streamStarted != nil {
|
||||||
|
started = *streamStarted
|
||||||
|
}
|
||||||
|
wroteFallback := h.ensureForwardErrorResponse(c, started)
|
||||||
|
requestLogger(c, "handler.openai_gateway.responses").Error(
|
||||||
|
"openai.responses_panic_recovered",
|
||||||
|
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||||
|
zap.Any("panic", recovered),
|
||||||
|
zap.ByteString("stack", debug.Stack()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
|
||||||
|
missing := h.missingResponsesDependencies()
|
||||||
|
if len(missing) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if reqLog == nil {
|
||||||
|
reqLog = requestLogger(c, "handler.openai_gateway.responses")
|
||||||
|
}
|
||||||
|
reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing))
|
||||||
|
|
||||||
|
if c != nil && c.Writer != nil && !c.Writer.Written() {
|
||||||
|
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "api_error",
|
||||||
|
"message": "Service temporarily unavailable",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string {
|
||||||
|
missing := make([]string, 0, 5)
|
||||||
|
if h == nil {
|
||||||
|
return append(missing, "handler")
|
||||||
|
}
|
||||||
|
if h.gatewayService == nil {
|
||||||
|
missing = append(missing, "gatewayService")
|
||||||
|
}
|
||||||
|
if h.billingCacheService == nil {
|
||||||
|
missing = append(missing, "billingCacheService")
|
||||||
|
}
|
||||||
|
if h.apiKeyService == nil {
|
||||||
|
missing = append(missing, "apiKeyService")
|
||||||
|
}
|
||||||
|
if h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil {
|
||||||
|
missing = append(missing, "concurrencyHelper")
|
||||||
|
}
|
||||||
|
return missing
|
||||||
|
}
|
||||||
|
|
||||||
func getContextInt64(c *gin.Context, key string) (int64, bool) {
|
func getContextInt64(c *gin.Context, key string) (int64, bool) {
|
||||||
if c == nil || key == "" {
|
if c == nil || key == "" {
|
||||||
return 0, false
|
return 0, false
|
||||||
@@ -444,6 +892,14 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
|
|||||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
defer func() {
|
||||||
|
if recovered := recover(); recovered != nil {
|
||||||
|
logger.L().With(
|
||||||
|
zap.String("component", "handler.openai_gateway.responses"),
|
||||||
|
zap.Any("panic", recovered),
|
||||||
|
).Error("openai.usage_record_task_panic_recovered")
|
||||||
|
}
|
||||||
|
}()
|
||||||
task(ctx)
|
task(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -515,19 +971,8 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
|||||||
// Stream already started, send error as SSE event then close
|
// Stream already started, send error as SSE event then close
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if ok {
|
if ok {
|
||||||
// Send error event in OpenAI SSE format with proper JSON marshaling
|
// SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。
|
||||||
errorData := map[string]any{
|
errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
|
||||||
"error": map[string]string{
|
|
||||||
"type": errType,
|
|
||||||
"message": message,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
jsonBytes, err := json.Marshal(errorData)
|
|
||||||
if err != nil {
|
|
||||||
_ = c.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
|
||||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
}
|
}
|
||||||
@@ -549,6 +994,16 @@ func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, stream
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool {
|
||||||
|
if wroteFallback {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if c == nil || c.Writer == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return c.Writer.Written()
|
||||||
|
}
|
||||||
|
|
||||||
// errorResponse returns OpenAI API format error response
|
// errorResponse returns OpenAI API format error response
|
||||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||||
c.JSON(status, gin.H{
|
c.JSON(status, gin.H{
|
||||||
@@ -558,3 +1013,61 @@ func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setOpenAIClientTransportHTTP(c *gin.Context) {
|
||||||
|
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setOpenAIClientTransportWS(c *gin.Context) {
|
||||||
|
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
|
||||||
|
gid := int64(0)
|
||||||
|
if groupID != nil {
|
||||||
|
gid = *groupID
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
|
||||||
|
if r == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade")
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) {
|
||||||
|
if conn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reason = strings.TrimSpace(reason)
|
||||||
|
if len(reason) > 120 {
|
||||||
|
reason = reason[:120]
|
||||||
|
}
|
||||||
|
_ = conn.Close(status, reason)
|
||||||
|
_ = conn.CloseNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeWSCloseErrorForLog(err error) (string, string) {
|
||||||
|
if err == nil {
|
||||||
|
return "-", "-"
|
||||||
|
}
|
||||||
|
statusCode := coderws.CloseStatus(err)
|
||||||
|
if statusCode == -1 {
|
||||||
|
return "-", "-"
|
||||||
|
}
|
||||||
|
closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
|
||||||
|
closeReason := "-"
|
||||||
|
var closeErr coderws.CloseError
|
||||||
|
if errors.As(err, &closeErr) {
|
||||||
|
reason := strings.TrimSpace(closeErr.Reason)
|
||||||
|
if reason != "" {
|
||||||
|
closeReason = reason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return closeStatus, closeReason
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,12 +1,19 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
coderws "github.com/coder/websocket"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -105,6 +112,27 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
|
|||||||
assert.Equal(t, "test error", errorObj["message"])
|
assert.Equal(t, "test error", errorObj["message"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReadRequestBodyWithPrealloc(t *testing.T) {
|
||||||
|
payload := `{"model":"gpt-5","input":"hello"}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload))
|
||||||
|
req.ContentLength = int64(len(payload))
|
||||||
|
|
||||||
|
body, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, payload, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadRequestBodyWithPrealloc_MaxBytesError(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(strings.Repeat("x", 8)))
|
||||||
|
req.Body = http.MaxBytesReader(rec, req.Body, 4)
|
||||||
|
|
||||||
|
_, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
|
||||||
|
require.Error(t, err)
|
||||||
|
var maxErr *http.MaxBytesError
|
||||||
|
require.ErrorAs(t, err, &maxErr)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -141,6 +169,387 @@ func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *test
|
|||||||
assert.Equal(t, "already written", w.Body.String())
|
assert.Equal(t, "already written", w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
t.Run("fallback_written_should_not_downgrade", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, true))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context_nil_should_not_downgrade", func(t *testing.T) {
|
||||||
|
require.False(t, shouldLogOpenAIForwardFailureAsWarn(nil, false))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("response_not_written_should_not_downgrade", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("response_already_written_should_downgrade", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.String(http.StatusForbidden, "already written")
|
||||||
|
require.True(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIRecoverResponsesPanic_WritesFallbackResponse(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
streamStarted := false
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
func() {
|
||||||
|
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||||
|
panic("test panic")
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||||
|
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
streamStarted := false
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
func() {
|
||||||
|
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
|
||||||
|
require.False(t, c.Writer.Written())
|
||||||
|
assert.Equal(t, "", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
c.String(http.StatusTeapot, "already written")
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
streamStarted := false
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
func() {
|
||||||
|
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||||
|
panic("test panic")
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusTeapot, w.Code)
|
||||||
|
assert.Equal(t, "already written", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIMissingResponsesDependencies(t *testing.T) {
|
||||||
|
t.Run("nil_handler", func(t *testing.T) {
|
||||||
|
var h *OpenAIGatewayHandler
|
||||||
|
require.Equal(t, []string{"handler"}, h.missingResponsesDependencies())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("all_dependencies_missing", func(t *testing.T) {
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
require.Equal(t,
|
||||||
|
[]string{"gatewayService", "billingCacheService", "apiKeyService", "concurrencyHelper"},
|
||||||
|
h.missingResponsesDependencies(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("all_dependencies_present", func(t *testing.T) {
|
||||||
|
h := &OpenAIGatewayHandler{
|
||||||
|
gatewayService: &service.OpenAIGatewayService{},
|
||||||
|
billingCacheService: &service.BillingCacheService{},
|
||||||
|
apiKeyService: &service.APIKeyService{},
|
||||||
|
concurrencyHelper: &ConcurrencyHelper{
|
||||||
|
concurrencyService: &service.ConcurrencyService{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Empty(t, h.missingResponsesDependencies())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
|
||||||
|
t.Run("missing_dependencies_returns_503", func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
ok := h.ensureResponsesDependencies(c, nil)
|
||||||
|
|
||||||
|
require.False(t, ok)
|
||||||
|
require.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||||
|
var parsed map[string]any
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
errorObj, exists := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, exists)
|
||||||
|
assert.Equal(t, "api_error", errorObj["type"])
|
||||||
|
assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("already_written_response_not_overridden", func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
c.String(http.StatusTeapot, "already written")
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
ok := h.ensureResponsesDependencies(c, nil)
|
||||||
|
|
||||||
|
require.False(t, ok)
|
||||||
|
require.Equal(t, http.StatusTeapot, w.Code)
|
||||||
|
assert.Equal(t, "already written", w.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("dependencies_ready_returns_true_and_no_write", func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{
|
||||||
|
gatewayService: &service.OpenAIGatewayService{},
|
||||||
|
billingCacheService: &service.BillingCacheService{},
|
||||||
|
apiKeyService: &service.APIKeyService{},
|
||||||
|
concurrencyHelper: &ConcurrencyHelper{
|
||||||
|
concurrencyService: &service.ConcurrencyService{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ok := h.ensureResponsesDependencies(c, nil)
|
||||||
|
|
||||||
|
require.True(t, ok)
|
||||||
|
require.False(t, c.Writer.Written())
|
||||||
|
assert.Equal(t, "", w.Body.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
groupID := int64(2)
|
||||||
|
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
|
||||||
|
ID: 10,
|
||||||
|
GroupID: &groupID,
|
||||||
|
})
|
||||||
|
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||||
|
UserID: 1,
|
||||||
|
Concurrency: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 故意使用未初始化依赖,验证快速失败而不是崩溃。
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
h.Responses(c)
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "api_error", errorObj["type"])
|
||||||
|
assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIResponses_SetsClientTransportHTTP(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(`{"model":"gpt-5"}`))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
h.Responses(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
|
||||||
|
`{"model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456","input":[{"type":"input_text","text":"hello"}]}`,
|
||||||
|
))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
groupID := int64(2)
|
||||||
|
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
|
||||||
|
ID: 101,
|
||||||
|
GroupID: &groupID,
|
||||||
|
User: &service.User{ID: 1},
|
||||||
|
})
|
||||||
|
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||||
|
UserID: 1,
|
||||||
|
Concurrency: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
|
||||||
|
h.Responses(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
|
||||||
|
c.Request.Header.Set("Upgrade", "websocket")
|
||||||
|
c.Request.Header.Set("Connection", "Upgrade")
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
h.ResponsesWebSocket(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
h.ResponsesWebSocket(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusUpgradeRequired, w.Code)
|
||||||
|
require.Equal(t, service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
|
||||||
|
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
|
||||||
|
defer wsServer.Close()
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
|
||||||
|
cancelDial()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = clientConn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
|
||||||
|
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`,
|
||||||
|
))
|
||||||
|
cancelWrite()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
_, _, err = clientConn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
require.Error(t, err)
|
||||||
|
var closeErr coderws.CloseError
|
||||||
|
require.ErrorAs(t, err, &closeErr)
|
||||||
|
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
|
||||||
|
require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
cache := &concurrencyCacheMock{
|
||||||
|
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return false, errors.New("user slot unavailable")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache)
|
||||||
|
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
|
||||||
|
defer wsServer.Close()
|
||||||
|
|
||||||
|
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
|
||||||
|
cancelDial()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = clientConn.CloseNow()
|
||||||
|
}()
|
||||||
|
|
||||||
|
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
|
||||||
|
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_123"}`,
|
||||||
|
))
|
||||||
|
cancelWrite()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
_, _, err = clientConn.Read(readCtx)
|
||||||
|
cancelRead()
|
||||||
|
require.Error(t, err)
|
||||||
|
var closeErr coderws.CloseError
|
||||||
|
require.ErrorAs(t, err, &closeErr)
|
||||||
|
require.Equal(t, coderws.StatusInternalError, closeErr.Code)
|
||||||
|
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
setOpenAIClientTransportHTTP(c)
|
||||||
|
require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetOpenAIClientTransportWS(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
setOpenAIClientTransportWS(c)
|
||||||
|
require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
|
||||||
|
}
|
||||||
|
|
||||||
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
||||||
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -228,3 +637,41 @@ func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
|
|||||||
require.NoError(t, setErr)
|
require.NoError(t, setErr)
|
||||||
require.True(t, gjson.ValidBytes(result))
|
require.True(t, gjson.ValidBytes(result))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newOpenAIHandlerForPreviousResponseIDValidation(t *testing.T, cache *concurrencyCacheMock) *OpenAIGatewayHandler {
|
||||||
|
t.Helper()
|
||||||
|
if cache == nil {
|
||||||
|
cache = &concurrencyCacheMock{
|
||||||
|
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
},
|
||||||
|
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &OpenAIGatewayHandler{
|
||||||
|
gatewayService: &service.OpenAIGatewayService{},
|
||||||
|
billingCacheService: &service.BillingCacheService{},
|
||||||
|
apiKeyService: &service.APIKeyService{},
|
||||||
|
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject middleware.AuthSubject) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
groupID := int64(2)
|
||||||
|
apiKey := &service.APIKey{
|
||||||
|
ID: 101,
|
||||||
|
GroupID: &groupID,
|
||||||
|
User: &service.User{ID: subject.UserID},
|
||||||
|
}
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(func(c *gin.Context) {
|
||||||
|
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||||
|
c.Set(string(middleware.ContextKeyUser), subject)
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||||
|
return httptest.NewServer(router)
|
||||||
|
}
|
||||||
|
|||||||
@@ -311,6 +311,35 @@ type opsCaptureWriter struct {
|
|||||||
buf bytes.Buffer
|
buf bytes.Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const opsCaptureWriterLimit = 64 * 1024
|
||||||
|
|
||||||
|
var opsCaptureWriterPool = sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return &opsCaptureWriter{limit: opsCaptureWriterLimit}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func acquireOpsCaptureWriter(rw gin.ResponseWriter) *opsCaptureWriter {
|
||||||
|
w, ok := opsCaptureWriterPool.Get().(*opsCaptureWriter)
|
||||||
|
if !ok || w == nil {
|
||||||
|
w = &opsCaptureWriter{}
|
||||||
|
}
|
||||||
|
w.ResponseWriter = rw
|
||||||
|
w.limit = opsCaptureWriterLimit
|
||||||
|
w.buf.Reset()
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func releaseOpsCaptureWriter(w *opsCaptureWriter) {
|
||||||
|
if w == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.ResponseWriter = nil
|
||||||
|
w.limit = opsCaptureWriterLimit
|
||||||
|
w.buf.Reset()
|
||||||
|
opsCaptureWriterPool.Put(w)
|
||||||
|
}
|
||||||
|
|
||||||
func (w *opsCaptureWriter) Write(b []byte) (int, error) {
|
func (w *opsCaptureWriter) Write(b []byte) (int, error) {
|
||||||
if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit {
|
if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit {
|
||||||
remaining := w.limit - w.buf.Len()
|
remaining := w.limit - w.buf.Len()
|
||||||
@@ -342,7 +371,16 @@ func (w *opsCaptureWriter) WriteString(s string) (int, error) {
|
|||||||
// - Streaming errors after the response has started (SSE) may still need explicit logging.
|
// - Streaming errors after the response has started (SSE) may still need explicit logging.
|
||||||
func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024}
|
originalWriter := c.Writer
|
||||||
|
w := acquireOpsCaptureWriter(originalWriter)
|
||||||
|
defer func() {
|
||||||
|
// Restore the original writer before returning so outer middlewares
|
||||||
|
// don't observe a pooled wrapper that has been released.
|
||||||
|
if c.Writer == w {
|
||||||
|
c.Writer = originalWriter
|
||||||
|
}
|
||||||
|
releaseOpsCaptureWriter(w)
|
||||||
|
}()
|
||||||
c.Writer = w
|
c.Writer = w
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -173,3 +174,43 @@ func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
|
|||||||
enqueueOpsErrorLog(ops, entry)
|
enqueueOpsErrorLog(ops, entry)
|
||||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpsCaptureWriterPool_ResetOnRelease(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
|
||||||
|
writer := acquireOpsCaptureWriter(c.Writer)
|
||||||
|
require.NotNil(t, writer)
|
||||||
|
_, err := writer.buf.WriteString("temp-error-body")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
releaseOpsCaptureWriter(writer)
|
||||||
|
|
||||||
|
reused := acquireOpsCaptureWriter(c.Writer)
|
||||||
|
defer releaseOpsCaptureWriter(reused)
|
||||||
|
|
||||||
|
require.Zero(t, reused.buf.Len(), "writer should be reset before reuse")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(middleware2.Recovery())
|
||||||
|
r.Use(middleware2.RequestLogger())
|
||||||
|
r.Use(middleware2.Logger())
|
||||||
|
r.GET("/v1/messages", OpsErrorLoggerMiddleware(nil), func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/messages", nil)
|
||||||
|
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
})
|
||||||
|
require.Equal(t, http.StatusNoContent, rec.Code)
|
||||||
|
}
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||||
|
SoraClientEnabled: settings.SoraClientEnabled,
|
||||||
Version: h.version,
|
Version: h.version,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
979
backend/internal/handler/sora_client_handler.go
Normal file
979
backend/internal/handler/sora_client_handler.go
Normal file
@@ -0,0 +1,979 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// 上游模型缓存 TTL
|
||||||
|
modelCacheTTL = 1 * time.Hour // 上游获取成功
|
||||||
|
modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
|
||||||
|
)
|
||||||
|
|
||||||
|
// SoraClientHandler 处理 Sora 客户端 API 请求。
|
||||||
|
type SoraClientHandler struct {
|
||||||
|
genService *service.SoraGenerationService
|
||||||
|
quotaService *service.SoraQuotaService
|
||||||
|
s3Storage *service.SoraS3Storage
|
||||||
|
soraGatewayService *service.SoraGatewayService
|
||||||
|
gatewayService *service.GatewayService
|
||||||
|
mediaStorage *service.SoraMediaStorage
|
||||||
|
apiKeyService *service.APIKeyService
|
||||||
|
|
||||||
|
// 上游模型缓存
|
||||||
|
modelCacheMu sync.RWMutex
|
||||||
|
cachedFamilies []service.SoraModelFamily
|
||||||
|
modelCacheTime time.Time
|
||||||
|
modelCacheUpstream bool // 是否来自上游(决定 TTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSoraClientHandler 创建 Sora 客户端 Handler。
|
||||||
|
func NewSoraClientHandler(
|
||||||
|
genService *service.SoraGenerationService,
|
||||||
|
quotaService *service.SoraQuotaService,
|
||||||
|
s3Storage *service.SoraS3Storage,
|
||||||
|
soraGatewayService *service.SoraGatewayService,
|
||||||
|
gatewayService *service.GatewayService,
|
||||||
|
mediaStorage *service.SoraMediaStorage,
|
||||||
|
apiKeyService *service.APIKeyService,
|
||||||
|
) *SoraClientHandler {
|
||||||
|
return &SoraClientHandler{
|
||||||
|
genService: genService,
|
||||||
|
quotaService: quotaService,
|
||||||
|
s3Storage: s3Storage,
|
||||||
|
soraGatewayService: soraGatewayService,
|
||||||
|
gatewayService: gatewayService,
|
||||||
|
mediaStorage: mediaStorage,
|
||||||
|
apiKeyService: apiKeyService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateRequest 生成请求。
|
||||||
|
type GenerateRequest struct {
|
||||||
|
Model string `json:"model" binding:"required"`
|
||||||
|
Prompt string `json:"prompt" binding:"required"`
|
||||||
|
MediaType string `json:"media_type"` // video / image,默认 video
|
||||||
|
VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3)
|
||||||
|
ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL)
|
||||||
|
APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate 异步生成 — 创建 pending 记录后立即返回。
|
||||||
|
// POST /api/v1/sora/generate
|
||||||
|
func (h *SoraClientHandler) Generate(c *gin.Context) {
|
||||||
|
userID := getUserIDFromContext(c)
|
||||||
|
if userID == 0 {
|
||||||
|
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req GenerateRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.MediaType == "" {
|
||||||
|
req.MediaType = "video"
|
||||||
|
}
|
||||||
|
req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
|
||||||
|
|
||||||
|
// 并发数检查(最多 3 个)
|
||||||
|
activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if activeCount >= 3 {
|
||||||
|
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 配额检查(粗略检查,实际文件大小在上传后才知道)
|
||||||
|
if h.quotaService != nil {
|
||||||
|
if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
|
||||||
|
var quotaErr *service.QuotaExceededError
|
||||||
|
if errors.As(err, "aErr) {
|
||||||
|
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Error(c, http.StatusForbidden, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取 API Key ID 和 Group ID
|
||||||
|
var apiKeyID *int64
|
||||||
|
var groupID *int64
|
||||||
|
|
||||||
|
if req.APIKeyID != nil && h.apiKeyService != nil {
|
||||||
|
// 前端传递了 api_key_id,需要校验
|
||||||
|
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusBadRequest, "API Key 不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if apiKey.UserID != userID {
|
||||||
|
response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if apiKey.Status != service.StatusAPIKeyActive {
|
||||||
|
response.Error(c, http.StatusForbidden, "API Key 不可用")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
apiKeyID = &apiKey.ID
|
||||||
|
groupID = apiKey.GroupID
|
||||||
|
} else if id, ok := c.Get("api_key_id"); ok {
|
||||||
|
// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
|
||||||
|
if v, ok := id.(int64); ok {
|
||||||
|
apiKeyID = &v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
|
||||||
|
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动后台异步生成 goroutine
|
||||||
|
go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"generation_id": gen.ID,
|
||||||
|
"status": gen.Status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// processGeneration 后台异步执行 Sora 生成任务。
|
||||||
|
// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
|
||||||
|
func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// 标记为生成中
|
||||||
|
if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
|
||||||
|
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
|
||||||
|
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"handler.sora_client",
|
||||||
|
"[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
|
||||||
|
genID,
|
||||||
|
userID,
|
||||||
|
groupIDForLog(groupID),
|
||||||
|
model,
|
||||||
|
mediaType,
|
||||||
|
videoCount,
|
||||||
|
strings.TrimSpace(imageInput) != "",
|
||||||
|
len(strings.TrimSpace(prompt)),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
|
||||||
|
if groupID == nil {
|
||||||
|
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.gatewayService == nil {
|
||||||
|
_ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 选择 Sora 账号
|
||||||
|
account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"handler.sora_client",
|
||||||
|
"[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
|
||||||
|
genID,
|
||||||
|
userID,
|
||||||
|
groupIDForLog(groupID),
|
||||||
|
model,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
_ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"handler.sora_client",
|
||||||
|
"[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
|
||||||
|
genID,
|
||||||
|
userID,
|
||||||
|
groupIDForLog(groupID),
|
||||||
|
model,
|
||||||
|
account.ID,
|
||||||
|
account.Name,
|
||||||
|
account.Platform,
|
||||||
|
account.Type,
|
||||||
|
)
|
||||||
|
|
||||||
|
// 构建 chat completions 请求体(非流式)
|
||||||
|
body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
|
||||||
|
|
||||||
|
if h.soraGatewayService == nil {
|
||||||
|
_ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
mockGinCtx, _ := gin.CreateTestContext(recorder)
|
||||||
|
mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
|
||||||
|
|
||||||
|
// 调用 Forward(非流式)
|
||||||
|
result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"handler.sora_client",
|
||||||
|
"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
|
||||||
|
genID,
|
||||||
|
account.ID,
|
||||||
|
model,
|
||||||
|
recorder.Code,
|
||||||
|
trimForLog(recorder.Body.String(), 400),
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
// 检查是否已取消
|
||||||
|
gen, _ := h.genService.GetByID(ctx, genID, userID)
|
||||||
|
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
|
||||||
|
mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
|
||||||
|
if mediaURL == "" {
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"handler.sora_client",
|
||||||
|
"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
|
||||||
|
genID,
|
||||||
|
account.ID,
|
||||||
|
model,
|
||||||
|
recorder.Code,
|
||||||
|
trimForLog(recorder.Body.String(), 400),
|
||||||
|
)
|
||||||
|
_ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查任务是否已被取消
|
||||||
|
gen, _ := h.genService.GetByID(ctx, genID, userID)
|
||||||
|
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||||
|
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 三层降级存储:S3 → 本地 → 上游临时 URL
|
||||||
|
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
|
||||||
|
|
||||||
|
usageAdded := false
|
||||||
|
if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
|
||||||
|
if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
|
||||||
|
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||||
|
var quotaErr *service.QuotaExceededError
|
||||||
|
if errors.As(err, "aErr) {
|
||||||
|
_ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
usageAdded = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
|
||||||
|
gen, _ = h.genService.GetByID(ctx, genID, userID)
|
||||||
|
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||||
|
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
|
||||||
|
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||||
|
if usageAdded && h.quotaService != nil {
|
||||||
|
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 标记完成
|
||||||
|
if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
|
||||||
|
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
|
||||||
|
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||||
|
if usageAdded && h.quotaService != nil {
|
||||||
|
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeMediaWithDegradation persists the generated media through a three-tier
// degradation chain: S3 → local disk → upstream temporary URL.
//
// Returns the primary stored URL, the full list of stored URLs, the storage
// type constant, the S3 object keys (S3 tier only, nil otherwise) and the
// total stored byte size (0 for the upstream tier, where nothing is copied).
func (h *SoraClientHandler) storeMediaWithDegradation(
	ctx context.Context, userID int64, mediaType string,
	mediaURL string, mediaURLs []string,
) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
	urls := mediaURLs
	if len(urls) == 0 {
		urls = []string{mediaURL}
	}

	// Tier 1: S3. All files must upload successfully; on any failure the
	// objects uploaded so far are deleted and we fall through to tier 2.
	if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
		keys := make([]string, 0, len(urls))
		var totalSize int64
		allOK := true
		for _, u := range urls {
			key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
			if err != nil {
				logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
				allOK = false
				// Clean up anything already uploaded in this batch.
				if len(keys) > 0 {
					_ = h.s3Storage.DeleteObjects(ctx, keys)
				}
				break
			}
			keys = append(keys, key)
			totalSize += size
		}
		if allOK && len(keys) > 0 {
			// Resolve access URLs; a failure here also rolls back all objects
			// and degrades to the next tier.
			accessURLs := make([]string, 0, len(keys))
			for _, key := range keys {
				accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
				if err != nil {
					logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
					_ = h.s3Storage.DeleteObjects(ctx, keys)
					allOK = false
					break
				}
				accessURLs = append(accessURLs, accessURL)
			}
			if allOK && len(accessURLs) > 0 {
				return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
			}
		}
	}

	// Tier 2: local storage. A size-accounting error is logged but tolerated
	// (the record is still stored locally, just with size 0 from the helper).
	if h.mediaStorage != nil && h.mediaStorage.Enabled() {
		storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
		if err == nil && len(storedPaths) > 0 {
			firstPath := storedPaths[0]
			totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
			if sizeErr != nil {
				logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
			}
			return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
		}
		logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
	}

	// Tier 3: keep the upstream temporary URLs (may expire; size unknown).
	return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
}
|
||||||
|
|
||||||
|
// buildAsyncRequestBody assembles the chat-completions request payload for an
// async Sora generation. image_input and video_count are only included when
// meaningful (non-empty image; more than one video).
func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
	payload := map[string]any{
		"model":  model,
		"stream": false,
		"messages": []map[string]string{
			{"role": "user", "content": prompt},
		},
	}
	if imageInput != "" {
		payload["image_input"] = imageInput
	}
	if videoCount > 1 {
		payload["video_count"] = videoCount
	}
	encoded, _ := json.Marshal(payload)
	return encoded
}
|
||||||
|
|
||||||
|
// normalizeVideoCount clamps the requested video count to the valid range.
// Non-video media always gets 1; video counts are clamped to [1, 3].
func normalizeVideoCount(mediaType string, videoCount int) int {
	if mediaType != "video" {
		return 1
	}
	switch {
	case videoCount <= 0:
		return 1
	case videoCount > 3:
		return 3
	default:
		return videoCount
	}
}
|
||||||
|
|
||||||
|
// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
|
||||||
|
// OAuth 路径:ForwardResult.MediaURL 已填充。
|
||||||
|
// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
|
||||||
|
func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
|
||||||
|
// 优先从 ForwardResult 获取(OAuth 路径)
|
||||||
|
if result != nil && result.MediaURL != "" {
|
||||||
|
// 尝试从响应体获取完整 URL 列表
|
||||||
|
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
|
||||||
|
return urls[0], urls
|
||||||
|
}
|
||||||
|
return result.MediaURL, []string{result.MediaURL}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从响应体解析(APIKey 路径)
|
||||||
|
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
|
||||||
|
return urls[0], urls
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseMediaURLsFromBody extracts media URLs from a JSON response body.
// It prefers the media_urls array (multi-item) and falls back to the single
// media_url field; empty strings are filtered out. Returns nil when nothing
// usable is found or the body is not valid JSON.
func parseMediaURLsFromBody(body []byte) []string {
	if len(body) == 0 {
		return nil
	}
	var payload map[string]any
	if json.Unmarshal(body, &payload) != nil {
		return nil
	}

	// media_urls (array) takes priority.
	if arr, ok := payload["media_urls"].([]any); ok && len(arr) > 0 {
		collected := make([]string, 0, len(arr))
		for _, item := range arr {
			if s, ok := item.(string); ok && s != "" {
				collected = append(collected, s)
			}
		}
		if len(collected) > 0 {
			return collected
		}
	}

	// Fall back to the single media_url field.
	if single, ok := payload["media_url"].(string); ok && single != "" {
		return []string{single}
	}

	return nil
}
|
||||||
|
|
||||||
|
// ListGenerations 查询生成记录列表。
|
||||||
|
// GET /api/v1/sora/generations
|
||||||
|
func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
|
||||||
|
userID := getUserIDFromContext(c)
|
||||||
|
if userID == 0 {
|
||||||
|
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
|
||||||
|
params := service.SoraGenerationListParams{
|
||||||
|
UserID: userID,
|
||||||
|
Status: c.Query("status"),
|
||||||
|
StorageType: c.Query("storage_type"),
|
||||||
|
MediaType: c.Query("media_type"),
|
||||||
|
Page: page,
|
||||||
|
PageSize: pageSize,
|
||||||
|
}
|
||||||
|
|
||||||
|
gens, total, err := h.genService.List(c.Request.Context(), params)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 为 S3 记录动态生成预签名 URL
|
||||||
|
for _, gen := range gens {
|
||||||
|
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"data": gens,
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeneration 查询生成记录详情。
|
||||||
|
// GET /api/v1/sora/generations/:id
|
||||||
|
func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
|
||||||
|
userID := getUserIDFromContext(c)
|
||||||
|
if userID == 0 {
|
||||||
|
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusNotFound, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
|
||||||
|
response.Success(c, gen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteGeneration 删除生成记录。
|
||||||
|
// DELETE /api/v1/sora/generations/:id
|
||||||
|
func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
|
||||||
|
userID := getUserIDFromContext(c)
|
||||||
|
if userID == 0 {
|
||||||
|
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusNotFound, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
|
||||||
|
if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
|
||||||
|
paths := gen.MediaURLs
|
||||||
|
if len(paths) == 0 && gen.MediaURL != "" {
|
||||||
|
paths = []string{gen.MediaURL}
|
||||||
|
}
|
||||||
|
if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
|
||||||
|
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
|
||||||
|
response.Error(c, http.StatusNotFound, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "已删除"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQuota 查询用户存储配额。
|
||||||
|
// GET /api/v1/sora/quota
|
||||||
|
func (h *SoraClientHandler) GetQuota(c *gin.Context) {
|
||||||
|
userID := getUserIDFromContext(c)
|
||||||
|
if userID == 0 {
|
||||||
|
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.quotaService == nil {
|
||||||
|
response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelGeneration 取消生成任务。
|
||||||
|
// POST /api/v1/sora/generations/:id/cancel
|
||||||
|
func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
|
||||||
|
userID := getUserIDFromContext(c)
|
||||||
|
if userID == 0 {
|
||||||
|
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 权限校验
|
||||||
|
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusNotFound, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = gen
|
||||||
|
|
||||||
|
if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
|
||||||
|
if errors.Is(err, service.ErrSoraGenerationNotActive) {
|
||||||
|
response.Error(c, http.StatusConflict, "任务已结束,无法取消")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Error(c, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "已取消"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveToStorage 手动保存 upstream 记录到 S3。
|
||||||
|
// POST /api/v1/sora/generations/:id/save
|
||||||
|
func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
|
||||||
|
userID := getUserIDFromContext(c)
|
||||||
|
if userID == 0 {
|
||||||
|
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, http.StatusNotFound, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if gen.StorageType != service.SoraStorageTypeUpstream {
|
||||||
|
response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if gen.MediaURL == "" {
|
||||||
|
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
|
||||||
|
response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceURLs := gen.MediaURLs
|
||||||
|
if len(sourceURLs) == 0 && gen.MediaURL != "" {
|
||||||
|
sourceURLs = []string{gen.MediaURL}
|
||||||
|
}
|
||||||
|
if len(sourceURLs) == 0 {
|
||||||
|
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadedKeys := make([]string, 0, len(sourceURLs))
|
||||||
|
accessURLs := make([]string, 0, len(sourceURLs))
|
||||||
|
var totalSize int64
|
||||||
|
|
||||||
|
for _, sourceURL := range sourceURLs {
|
||||||
|
objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
|
||||||
|
if uploadErr != nil {
|
||||||
|
if len(uploadedKeys) > 0 {
|
||||||
|
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||||
|
}
|
||||||
|
var upstreamErr *service.UpstreamDownloadError
|
||||||
|
if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
|
||||||
|
response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
|
||||||
|
if err != nil {
|
||||||
|
uploadedKeys = append(uploadedKeys, objectKey)
|
||||||
|
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||||
|
response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
uploadedKeys = append(uploadedKeys, objectKey)
|
||||||
|
accessURLs = append(accessURLs, accessURL)
|
||||||
|
totalSize += fileSize
|
||||||
|
}
|
||||||
|
|
||||||
|
usageAdded := false
|
||||||
|
if totalSize > 0 && h.quotaService != nil {
|
||||||
|
if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
|
||||||
|
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||||
|
var quotaErr *service.QuotaExceededError
|
||||||
|
if errors.As(err, "aErr) {
|
||||||
|
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
usageAdded = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.genService.UpdateStorageForCompleted(
|
||||||
|
c.Request.Context(),
|
||||||
|
id,
|
||||||
|
accessURLs[0],
|
||||||
|
accessURLs,
|
||||||
|
service.SoraStorageTypeS3,
|
||||||
|
uploadedKeys,
|
||||||
|
totalSize,
|
||||||
|
); err != nil {
|
||||||
|
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||||
|
if usageAdded && h.quotaService != nil {
|
||||||
|
_ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
|
||||||
|
}
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"message": "已保存到 S3",
|
||||||
|
"object_key": uploadedKeys[0],
|
||||||
|
"object_keys": uploadedKeys,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStorageStatus 返回存储状态。
|
||||||
|
// GET /api/v1/sora/storage-status
|
||||||
|
func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
|
||||||
|
s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
|
||||||
|
s3Healthy := false
|
||||||
|
if s3Enabled {
|
||||||
|
s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
|
||||||
|
}
|
||||||
|
localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"s3_enabled": s3Enabled,
|
||||||
|
"s3_healthy": s3Healthy,
|
||||||
|
"local_enabled": localEnabled,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
|
||||||
|
switch storageType {
|
||||||
|
case service.SoraStorageTypeS3:
|
||||||
|
if h.s3Storage != nil && len(s3Keys) > 0 {
|
||||||
|
if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
|
||||||
|
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case service.SoraStorageTypeLocal:
|
||||||
|
if h.mediaStorage != nil && len(localPaths) > 0 {
|
||||||
|
if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
|
||||||
|
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getUserIDFromContext 从 gin 上下文中提取用户 ID。
|
||||||
|
func getUserIDFromContext(c *gin.Context) int64 {
|
||||||
|
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
|
||||||
|
return subject.UserID
|
||||||
|
}
|
||||||
|
|
||||||
|
if id, ok := c.Get("user_id"); ok {
|
||||||
|
switch v := id.(type) {
|
||||||
|
case int64:
|
||||||
|
return v
|
||||||
|
case float64:
|
||||||
|
return int64(v)
|
||||||
|
case string:
|
||||||
|
n, _ := strconv.ParseInt(v, 10, 64)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 尝试从 JWT claims 获取
|
||||||
|
if id, ok := c.Get("userID"); ok {
|
||||||
|
if v, ok := id.(int64); ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupIDForLog renders an optional group ID for log output; nil logs as 0.
func groupIDForLog(groupID *int64) int64 {
	var v int64
	if groupID != nil {
		v = *groupID
	}
	return v
}
|
||||||
|
|
||||||
|
// trimForLog trims surrounding whitespace and truncates the string to at most
// maxLen bytes for log output; maxLen <= 0 disables truncation.
//
// BUG FIX: the original sliced at a raw byte index, which could split a
// multi-byte UTF-8 character and emit invalid text into the logs. The cut
// now backs off to the nearest rune boundary.
func trimForLog(raw string, maxLen int) string {
	trimmed := strings.TrimSpace(raw)
	if maxLen <= 0 || len(trimmed) <= maxLen {
		return trimmed
	}
	cut := maxLen
	// Step back over UTF-8 continuation bytes (0b10xxxxxx) so the cut never
	// lands in the middle of a rune.
	for cut > 0 && trimmed[cut]&0xC0 == 0x80 {
		cut--
	}
	return trimmed[:cut] + "...(truncated)"
}
|
||||||
|
|
||||||
|
// GetModels 获取可用 Sora 模型家族列表。
|
||||||
|
// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
|
||||||
|
// GET /api/v1/sora/models
|
||||||
|
func (h *SoraClientHandler) GetModels(c *gin.Context) {
|
||||||
|
families := h.getModelFamilies(c.Request.Context())
|
||||||
|
response.Success(c, families)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getModelFamilies returns the Sora model family list, backed by an in-memory
// cache. The cache TTL depends on the data's origin: upstream data uses
// modelCacheTTL, the local fallback uses modelCacheFailedTTL (both declared
// elsewhere in this package — presumably the failed TTL is shorter so an
// upstream outage is retried sooner; confirm against the constant values).
func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
	// Fast path: check the cache under the read lock.
	h.modelCacheMu.RLock()
	ttl := modelCacheTTL
	if !h.modelCacheUpstream {
		ttl = modelCacheFailedTTL
	}
	if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
		families := h.cachedFamilies
		h.modelCacheMu.RUnlock()
		return families
	}
	h.modelCacheMu.RUnlock()

	// Slow path: refresh under the write lock.
	h.modelCacheMu.Lock()
	defer h.modelCacheMu.Unlock()

	// Double-check: another goroutine may have refreshed while we waited.
	ttl = modelCacheTTL
	if !h.modelCacheUpstream {
		ttl = modelCacheFailedTTL
	}
	if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
		return h.cachedFamilies
	}

	// Try upstream first; on failure cache the locally configured families
	// with modelCacheUpstream=false so the shorter TTL applies.
	families, err := h.fetchUpstreamModels(ctx)
	if err != nil {
		logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
		families = service.BuildSoraModelFamilies()
		h.cachedFamilies = families
		h.modelCacheTime = time.Now()
		h.modelCacheUpstream = false
		return families
	}

	logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
	h.cachedFamilies = families
	h.modelCacheTime = time.Now()
	h.modelCacheUpstream = true
	return families
}
|
||||||
|
|
||||||
|
// fetchUpstreamModels pulls the model list from the upstream Sora API and
// converts it into model families.
//
// Only API-key accounts are supported: the handler needs the raw api_key and
// base_url credentials to call GET <base_url>/sora/v1/models directly. The
// response is expected in OpenAI list format ({"data":[{"id":...}]}) and the
// body read is capped at 1 MiB. Any failure is returned to the caller, which
// falls back to the locally configured model families.
func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
	if h.gatewayService == nil {
		return nil, fmt.Errorf("gatewayService 未初始化")
	}

	// Force the Sora platform for account selection.
	ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)

	// Select any Sora account, probing with a representative model ID.
	account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
	if err != nil {
		return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
	}

	// Only API-key accounts can be used for model sync.
	if account.Type != service.AccountTypeAPIKey {
		return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
	}

	apiKey := account.GetCredential("api_key")
	if apiKey == "" {
		return nil, fmt.Errorf("账号缺少 api_key")
	}

	baseURL := account.GetBaseURL()
	if baseURL == "" {
		return nil, fmt.Errorf("账号缺少 base_url")
	}

	// Build the upstream model-list request URL.
	modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"

	reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+apiKey)

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求上游失败: %w", err)
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
	}

	// Cap the response body at 1 MiB to bound memory usage.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}

	// Parse the OpenAI-format model list.
	var modelsResp struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if err := json.Unmarshal(body, &modelsResp); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}

	if len(modelsResp.Data) == 0 {
		return nil, fmt.Errorf("上游返回空模型列表")
	}

	// Collect the model IDs.
	modelIDs := make([]string, 0, len(modelsResp.Data))
	for _, m := range modelsResp.Data {
		modelIDs = append(modelIDs, m.ID)
	}

	// Convert IDs into model families.
	families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
	if len(families) == 0 {
		return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
	}

	return families, nil
}
|
||||||
3135
backend/internal/handler/sora_client_handler_test.go
Normal file
3135
backend/internal/handler/sora_client_handler_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
@@ -17,6 +16,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
@@ -107,7 +107,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
zap.Any("group_id", apiKey.GroupID),
|
zap.Any("group_id", apiKey.GroupID),
|
||||||
)
|
)
|
||||||
|
|
||||||
body, err := io.ReadAll(c.Request.Body)
|
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||||
@@ -461,6 +461,14 @@ func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask)
|
|||||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
defer func() {
|
||||||
|
if recovered := recover(); recovered != nil {
|
||||||
|
logger.L().With(
|
||||||
|
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||||
|
zap.Any("panic", recovered),
|
||||||
|
).Error("sora.usage_record_task_panic_recovered")
|
||||||
|
}
|
||||||
|
}()
|
||||||
task(ctx)
|
task(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -314,10 +314,10 @@ func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID i
|
|||||||
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
@@ -65,8 +66,17 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
// Parse additional filters
|
// Parse additional filters
|
||||||
model := c.Query("model")
|
model := c.Query("model")
|
||||||
|
|
||||||
|
var requestType *int16
|
||||||
var stream *bool
|
var stream *bool
|
||||||
if streamStr := c.Query("stream"); streamStr != "" {
|
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||||
|
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value := int16(parsed)
|
||||||
|
requestType = &value
|
||||||
|
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||||
val, err := strconv.ParseBool(streamStr)
|
val, err := strconv.ParseBool(streamStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||||
@@ -114,6 +124,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
|||||||
UserID: subject.UserID, // Always filter by current user for security
|
UserID: subject.UserID, // Always filter by current user for security
|
||||||
APIKeyID: apiKeyID,
|
APIKeyID: apiKeyID,
|
||||||
Model: model,
|
Model: model,
|
||||||
|
RequestType: requestType,
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
BillingType: billingType,
|
BillingType: billingType,
|
||||||
StartTime: startTime,
|
StartTime: startTime,
|
||||||
|
|||||||
80
backend/internal/handler/usage_handler_request_type_test.go
Normal file
80
backend/internal/handler/usage_handler_request_type_test.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type userUsageRepoCapture struct {
|
||||||
|
service.UsageLogRepository
|
||||||
|
listFilters usagestats.UsageLogFilters
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
|
s.listFilters = filters
|
||||||
|
return []service.UsageLog{}, &pagination.PaginationResult{
|
||||||
|
Total: 0,
|
||||||
|
Page: params.Page,
|
||||||
|
PageSize: params.PageSize,
|
||||||
|
Pages: 0,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUserUsageRequestTypeTestRouter(repo *userUsageRepoCapture) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
usageSvc := service.NewUsageService(repo, nil, nil, nil)
|
||||||
|
handler := NewUsageHandler(usageSvc, nil)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(func(c *gin.Context) {
|
||||||
|
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
router.GET("/usage", handler.List)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserUsageListRequestTypePriority(t *testing.T) {
|
||||||
|
repo := &userUsageRepoCapture{}
|
||||||
|
router := newUserUsageRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/usage?request_type=ws_v2&stream=bad", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.Equal(t, int64(42), repo.listFilters.UserID)
|
||||||
|
require.NotNil(t, repo.listFilters.RequestType)
|
||||||
|
require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
|
||||||
|
require.Nil(t, repo.listFilters.Stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserUsageListInvalidRequestType(t *testing.T) {
|
||||||
|
repo := &userUsageRepoCapture{}
|
||||||
|
router := newUserUsageRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/usage?request_type=invalid", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserUsageListInvalidStream(t *testing.T) {
|
||||||
|
repo := &userUsageRepoCapture{}
|
||||||
|
router := newUserUsageRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/usage?stream=invalid", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
}
|
||||||
@@ -61,6 +61,22 @@ func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
var called atomic.Bool
|
||||||
|
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
|
panic("usage task panic")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
|
called.Store(true)
|
||||||
|
})
|
||||||
|
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||||
pool := newUsageRecordTestPool(t)
|
pool := newUsageRecordTestPool(t)
|
||||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||||
@@ -98,6 +114,22 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
var called atomic.Bool
|
||||||
|
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
|
panic("usage task panic")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
|
called.Store(true)
|
||||||
|
})
|
||||||
|
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||||
|
}
|
||||||
|
|
||||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||||
pool := newUsageRecordTestPool(t)
|
pool := newUsageRecordTestPool(t)
|
||||||
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
|
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
|
||||||
@@ -134,3 +166,19 @@ func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
|||||||
h.submitUsageRecordTask(nil)
|
h.submitUsageRecordTask(nil)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
||||||
|
h := &SoraGatewayHandler{}
|
||||||
|
var called atomic.Bool
|
||||||
|
|
||||||
|
require.NotPanics(t, func() {
|
||||||
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
|
panic("usage task panic")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
|
called.Store(true)
|
||||||
|
})
|
||||||
|
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ func ProvideAdminHandlers(
|
|||||||
groupHandler *admin.GroupHandler,
|
groupHandler *admin.GroupHandler,
|
||||||
accountHandler *admin.AccountHandler,
|
accountHandler *admin.AccountHandler,
|
||||||
announcementHandler *admin.AnnouncementHandler,
|
announcementHandler *admin.AnnouncementHandler,
|
||||||
|
dataManagementHandler *admin.DataManagementHandler,
|
||||||
oauthHandler *admin.OAuthHandler,
|
oauthHandler *admin.OAuthHandler,
|
||||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||||
@@ -35,6 +36,7 @@ func ProvideAdminHandlers(
|
|||||||
Group: groupHandler,
|
Group: groupHandler,
|
||||||
Account: accountHandler,
|
Account: accountHandler,
|
||||||
Announcement: announcementHandler,
|
Announcement: announcementHandler,
|
||||||
|
DataManagement: dataManagementHandler,
|
||||||
OAuth: oauthHandler,
|
OAuth: oauthHandler,
|
||||||
OpenAIOAuth: openaiOAuthHandler,
|
OpenAIOAuth: openaiOAuthHandler,
|
||||||
GeminiOAuth: geminiOAuthHandler,
|
GeminiOAuth: geminiOAuthHandler,
|
||||||
@@ -75,6 +77,7 @@ func ProvideHandlers(
|
|||||||
gatewayHandler *GatewayHandler,
|
gatewayHandler *GatewayHandler,
|
||||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||||
soraGatewayHandler *SoraGatewayHandler,
|
soraGatewayHandler *SoraGatewayHandler,
|
||||||
|
soraClientHandler *SoraClientHandler,
|
||||||
settingHandler *SettingHandler,
|
settingHandler *SettingHandler,
|
||||||
totpHandler *TotpHandler,
|
totpHandler *TotpHandler,
|
||||||
_ *service.IdempotencyCoordinator,
|
_ *service.IdempotencyCoordinator,
|
||||||
@@ -92,6 +95,7 @@ func ProvideHandlers(
|
|||||||
Gateway: gatewayHandler,
|
Gateway: gatewayHandler,
|
||||||
OpenAIGateway: openaiGatewayHandler,
|
OpenAIGateway: openaiGatewayHandler,
|
||||||
SoraGateway: soraGatewayHandler,
|
SoraGateway: soraGatewayHandler,
|
||||||
|
SoraClient: soraClientHandler,
|
||||||
Setting: settingHandler,
|
Setting: settingHandler,
|
||||||
Totp: totpHandler,
|
Totp: totpHandler,
|
||||||
}
|
}
|
||||||
@@ -119,6 +123,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
admin.NewGroupHandler,
|
admin.NewGroupHandler,
|
||||||
admin.NewAccountHandler,
|
admin.NewAccountHandler,
|
||||||
admin.NewAnnouncementHandler,
|
admin.NewAnnouncementHandler,
|
||||||
|
admin.NewDataManagementHandler,
|
||||||
admin.NewOAuthHandler,
|
admin.NewOAuthHandler,
|
||||||
admin.NewOpenAIOAuthHandler,
|
admin.NewOpenAIOAuthHandler,
|
||||||
admin.NewGeminiOAuthHandler,
|
admin.NewGeminiOAuthHandler,
|
||||||
|
|||||||
@@ -152,6 +152,7 @@ var claudeModels = []modelDef{
|
|||||||
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
|
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||||
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
|
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||||
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
|
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
|
||||||
|
{ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
|
||||||
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
|
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,6 +166,8 @@ var geminiModels = []modelDef{
|
|||||||
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
|
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||||
{ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"},
|
{ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||||
{ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"},
|
{ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||||
|
{ID: "gemini-3.1-flash-image", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||||
|
{ID: "gemini-3.1-flash-image-preview", DisplayName: "Gemini 3.1 Flash Image Preview", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||||
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
|
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||||
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
|
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||||
}
|
}
|
||||||
|
|||||||
26
backend/internal/pkg/antigravity/claude_types_test.go
Normal file
26
backend/internal/pkg/antigravity/claude_types_test.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package antigravity
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
models := DefaultModels()
|
||||||
|
byID := make(map[string]ClaudeModel, len(models))
|
||||||
|
for _, m := range models {
|
||||||
|
byID[m.ID] = m
|
||||||
|
}
|
||||||
|
|
||||||
|
requiredIDs := []string{
|
||||||
|
"claude-opus-4-6-thinking",
|
||||||
|
"gemini-3.1-flash-image",
|
||||||
|
"gemini-3.1-flash-image-preview",
|
||||||
|
"gemini-3-pro-image", // legacy compatibility
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range requiredIDs {
|
||||||
|
if _, ok := byID[id]; !ok {
|
||||||
|
t.Fatalf("expected model %q to be exposed in DefaultModels", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -70,7 +70,7 @@ type GeminiGenerationConfig struct {
|
|||||||
ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"`
|
ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GeminiImageConfig Gemini 图片生成配置(仅 gemini-3-pro-image 支持)
|
// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image / gemini-3.1-flash-image 等图片模型支持)
|
||||||
type GeminiImageConfig struct {
|
type GeminiImageConfig struct {
|
||||||
AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4"
|
AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4"
|
||||||
ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K"
|
ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K"
|
||||||
|
|||||||
@@ -53,7 +53,8 @@ const (
|
|||||||
var defaultUserAgentVersion = "1.19.6"
|
var defaultUserAgentVersion = "1.19.6"
|
||||||
|
|
||||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||||
var defaultClientSecret = "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"
|
// 默认值使用占位符,生产环境请通过环境变量注入真实值。
|
||||||
|
var defaultClientSecret = "GOCSPX-your-client-secret"
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// 从环境变量读取版本号,未设置则使用默认值
|
// 从环境变量读取版本号,未设置则使用默认值
|
||||||
|
|||||||
@@ -612,14 +612,14 @@ func TestBuildAuthorizationURL_参数验证(t *testing.T) {
|
|||||||
|
|
||||||
expectedParams := map[string]string{
|
expectedParams := map[string]string{
|
||||||
"client_id": ClientID,
|
"client_id": ClientID,
|
||||||
"redirect_uri": RedirectURI,
|
"redirect_uri": RedirectURI,
|
||||||
"response_type": "code",
|
"response_type": "code",
|
||||||
"scope": Scopes,
|
"scope": Scopes,
|
||||||
"state": state,
|
"state": state,
|
||||||
"code_challenge": codeChallenge,
|
"code_challenge": codeChallenge,
|
||||||
"code_challenge_method": "S256",
|
"code_challenge_method": "S256",
|
||||||
"access_type": "offline",
|
"access_type": "offline",
|
||||||
"prompt": "consent",
|
"prompt": "consent",
|
||||||
"include_granted_scopes": "true",
|
"include_granted_scopes": "true",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -684,7 +684,7 @@ func TestConstants_值正确(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
|
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
|
||||||
}
|
}
|
||||||
if secret != "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" {
|
if secret != "GOCSPX-your-client-secret" {
|
||||||
t.Errorf("默认 client_secret 不匹配: got %s", secret)
|
t.Errorf("默认 client_secret 不匹配: got %s", secret)
|
||||||
}
|
}
|
||||||
if RedirectURI != "http://localhost:8085/callback" {
|
if RedirectURI != "http://localhost:8085/callback" {
|
||||||
|
|||||||
@@ -166,3 +166,18 @@ func TestToHTTP(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToHTTP_MetadataDeepCopy(t *testing.T) {
|
||||||
|
md := map[string]string{"k": "v"}
|
||||||
|
appErr := BadRequest("BAD_REQUEST", "invalid").WithMetadata(md)
|
||||||
|
|
||||||
|
code, body := ToHTTP(appErr)
|
||||||
|
require.Equal(t, http.StatusBadRequest, code)
|
||||||
|
require.Equal(t, "v", body.Metadata["k"])
|
||||||
|
|
||||||
|
md["k"] = "changed"
|
||||||
|
require.Equal(t, "v", body.Metadata["k"])
|
||||||
|
|
||||||
|
appErr.Metadata["k"] = "changed-again"
|
||||||
|
require.Equal(t, "v", body.Metadata["k"])
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,6 +16,16 @@ func ToHTTP(err error) (statusCode int, body Status) {
|
|||||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||||
}
|
}
|
||||||
|
|
||||||
cloned := Clone(appErr)
|
body = Status{
|
||||||
return int(cloned.Code), cloned.Status
|
Code: appErr.Code,
|
||||||
|
Reason: appErr.Reason,
|
||||||
|
Message: appErr.Message,
|
||||||
|
}
|
||||||
|
if appErr.Metadata != nil {
|
||||||
|
body.Metadata = make(map[string]string, len(appErr.Metadata))
|
||||||
|
for k, v := range appErr.Metadata {
|
||||||
|
body.Metadata[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return int(appErr.Code), body
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user